acm.model
- acm.model.train.TrainFCN(lhc_y: ndarray, lhc_x: ndarray, lhc_x_names: list, covariance_matrix: ndarray, stat_name: str, model_dir: str, n_test: int | list, learning_rate: float, n_hidden: list, dropout_rate: float, weight_decay: float, act_fn: str = 'learned_sigmoid', loss: str = 'mae', max_epochs: int = 5000, log_dir: str = None, seed: int = None, transform=None) → float
Train a Fully Connected Neural Network (FCN) emulator for the given statistic with the given hyperparameters. The LHC data and the covariance matrix must be in the same format as those used in the ACM pipeline.
- Parameters:
lhc_y (np.ndarray) – LHC y data for the statistic to train on (outputs).
lhc_x (np.ndarray) – LHC x data for the statistic to train on (inputs).
lhc_x_names (list) – List of the names of the input parameters.
covariance_matrix (np.ndarray) – Covariance matrix for the statistic to train on.
stat_name (str) – Name of the statistic to train on.
model_dir (str) – Directory in which to save the trained model.
n_test (int|list) – Number of test samples to hold out from the LHC data. Must be smaller than the total number of samples. If a list is provided, those indices are used to select the test samples, which are excluded from the training set. If an integer is provided, the first n_test samples are used for testing. Set to 0 to use all the samples for training.
learning_rate (float) – Learning rate for the optimizer.
n_hidden (list) – List of integers, number of neurons in each hidden layer.
dropout_rate (float) – Dropout rate for the hidden layers.
weight_decay (float) – Weight decay for the optimizer.
act_fn (str, optional) – Activation function for the hidden layers. Defaults to ‘learned_sigmoid’.
loss (str, optional) – Loss function to use. Defaults to ‘mae’.
max_epochs (int, optional) – Maximum number of epochs to train the model. Defaults to 5000.
log_dir (str, optional) – Directory in which to save the PyTorch Lightning logs. If set to None, the logs are saved in the current directory. Defaults to None.
transform (callable, optional) – Transform to apply to the output features, from the sunbird.data.transforms or sunbird.data.transforms_array modules. Defaults to None.
- Returns:
Validation loss of the model.
- Return type:
float
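The rules for n_test can be made concrete with a small index-splitting helper. This is a sketch of the documented behaviour only, not the code inside acm.model.train; the helper name split_train_test is hypothetical, and it returns plain index lists so that it needs nothing beyond the standard library.

```python
from typing import List, Sequence, Tuple, Union


def split_train_test(
    n_samples: int, n_test: Union[int, Sequence[int]]
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: return (train_indices, test_indices) per the documented n_test rules."""
    if isinstance(n_test, int):
        # Integer: the first n_test samples are held out (0 -> train on everything).
        if not 0 <= n_test < n_samples:
            raise ValueError("n_test must be non-negative and smaller than n_samples")
        test_idx = list(range(n_test))
    else:
        # List: exactly these sample indices are held out for testing.
        test_idx = list(n_test)
        if any(not 0 <= i < n_samples for i in test_idx):
            raise IndexError("test index out of range")
    held_out = set(test_idx)
    # Every sample that is not held out goes into the training set.
    train_idx = [i for i in range(n_samples) if i not in held_out]
    return train_idx, test_idx
```

Assuming the samples lie along the first axis of lhc_x and lhc_y, the returned index lists can be used directly for NumPy fancy indexing, e.g. lhc_x[train_idx] and lhc_y[test_idx]. For instance, split_train_test(10, 3) holds out indices [0, 1, 2] and trains on [3, ..., 9].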
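The n_hidden list fixes the widths of the hidden layers, and act_fn selects their activation. The sketch below assumes a plain stack of dense layers (input, one layer per entry of n_hidden, then an output layer sized to the statistic) and assumes that 'learned_sigmoid' follows the form popularised by CosmoPower, f(x) = [γ + σ(βx)(1 - γ)]·x, with γ and β trainable per neuron. Both are assumptions about the underlying sunbird implementation, and the function names are hypothetical.

```python
import math
from typing import List, Tuple


def layer_shapes(n_input: int, n_hidden: List[int], n_output: int) -> List[Tuple[int, int]]:
    """Hypothetical: (fan_in, fan_out) of each dense layer, input -> hidden layers -> output."""
    widths = [n_input, *n_hidden, n_output]
    return list(zip(widths[:-1], widths[1:]))


def _sigmoid(z: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def learned_sigmoid(x: float, gamma: float, beta: float) -> float:
    """Assumed CosmoPower-style form: [gamma + sigmoid(beta * x) * (1 - gamma)] * x."""
    return (gamma + _sigmoid(beta * x) * (1.0 - gamma)) * x
```

For example, layer_shapes(5, [128, 128], 30) gives [(5, 128), (128, 128), (128, 30)]. With γ = 1 the activation reduces to the identity, and with γ = 0 and β = 1 it becomes x·σ(x) (SiLU/Swish), so learning γ and β lets each neuron move between linear and nonlinear behaviour.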
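Since TrainFCN returns the validation loss as a single float, it can serve directly as the objective of a hyperparameter search. A minimal random search is sketched below; random_search is a hypothetical helper, and train_fn stands in for any callable that accepts hyperparameters as keyword arguments and returns a loss to be minimised.

```python
import random
from typing import Any, Callable, Dict, List, Tuple


def random_search(
    train_fn: Callable[..., float],
    space: Dict[str, List[Any]],
    n_trials: int,
    seed: int = 0,
) -> Tuple[Dict[str, Any], float]:
    """Hypothetical helper: sample hyperparameters at random, keep the lowest validation loss."""
    rng = random.Random(seed)
    best_params: Dict[str, Any] = {}
    best_loss = float("inf")
    for _ in range(n_trials):
        # Draw one candidate value for every hyperparameter in the search space.
        params = {name: rng.choice(values) for name, values in space.items()}
        loss = train_fn(**params)  # e.g. a wrapper around TrainFCN returning its float
        if loss < best_loss:
            best_params, best_loss = params, loss
    return best_params, best_loss
```

To search over TrainFCN itself, the data and bookkeeping arguments (lhc_y, lhc_x, lhc_x_names, covariance_matrix, stat_name, model_dir, n_test) could be fixed with functools.partial, leaving keys such as learning_rate, n_hidden, dropout_rate and weight_decay in the search space. Random search is used here only for brevity; a dedicated library such as Optuna provides smarter samplers and pruning.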