BaseTrainer
- class franken.trainers.BaseTrainer(train_dataloader, log_dir=None, save_every_model=True, device='cpu', dtype=torch.float32)
Bases: ABC
Base trainer class. Requires fit() and evaluate() methods.
- abstract evaluate(model, dataloader, log_collection, all_weights, metrics)
Evaluate a fitted model by computing metrics on a validation dataset.
- Parameters:
model (FrankenPotential) – The model which defines GNN and random features.
dataloader (torch.utils.data.DataLoader) – Evaluation will run the model on each configuration in the dataloader, computing averaged metrics.
log_collection (LogCollection) – Log object as output by the fit() method. Metric values will be added to the logs and the same object will be returned by this method.
all_weights (torch.Tensor) – The weights as output by the fit() method.
metrics (list[str]) – List of metrics which should be computed.
- Returns:
Logs which contain all parameters related to the fitting, as well as timings and metrics.
- Return type:
logs (LogCollection)
- abstract fit(model, solver_params)
Fit a given franken model on the training set.
- Parameters:
model (FrankenPotential) – The model which defines GNN and random features.
solver_params (dict) – Parameters for the solver which actually performs the fit.
- Returns:
Logs which contain all parameters related to the fitting, as well as timings.
Weights which were learned during the fit.
- Return type:
tuple[LogCollection, torch.Tensor]
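The contract between fit() and evaluate() described above can be illustrated with a minimal, self-contained sketch. It does not import franken: `SketchTrainer` and `MeanTrainer` are hypothetical stand-ins that only mirror the documented method signatures. Plain dicts and lists replace LogCollection and torch.Tensor, and the `l2_penalty` solver key is an assumption made for illustration.

```python
from abc import ABC, abstractmethod


class SketchTrainer(ABC):
    """Hypothetical stand-in mirroring the documented interface (not franken's class)."""

    def __init__(self, train_dataloader):
        self.train_dataloader = train_dataloader

    @abstractmethod
    def fit(self, model, solver_params):
        """Return a (logs, weights) pair."""

    @abstractmethod
    def evaluate(self, model, dataloader, log_collection, all_weights, metrics):
        """Add metric values to log_collection and return the same object."""


class MeanTrainer(SketchTrainer):
    """Toy trainer: 'fitting' learns one weight, the mean of the training targets."""

    def fit(self, model, solver_params):
        targets = [y for _, y in self.train_dataloader]
        weights = [sum(targets) / len(targets)]
        # A dict stands in for LogCollection here.
        logs = {"solver_params": dict(solver_params), "metrics": {}}
        return logs, weights

    def evaluate(self, model, dataloader, log_collection, all_weights, metrics):
        errors = [abs(all_weights[0] - y) for _, y in dataloader]
        if "mae" in metrics:
            log_collection["metrics"]["mae"] = sum(errors) / len(errors)
        # The same log object is returned, as the evaluate() docs describe.
        return log_collection


# Lists of (input, target) pairs stand in for DataLoader objects.
train = [(0, 1.0), (1, 3.0)]
val = [(2, 2.0), (3, 4.0)]

trainer = MeanTrainer(train)
logs, weights = trainer.fit(model=None, solver_params={"l2_penalty": 1e-6})
same_logs = trainer.evaluate(None, val, logs, weights, metrics=["mae"])
```

Because both methods are abstract, instantiating `SketchTrainer` directly raises `TypeError`. A concrete subclass has to provide both fit() and evaluate().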
- get_statistics(model)
Compute statistics on the training dataset with the provided model.
- Parameters:
model (FrankenPotential) – Franken model from which the attached GNN is used to compute the features on atomic configurations.
- Returns:
A tuple containing an object of type franken.rf.scaler.Statistics containing the dataset statistics, and a dictionary containing the GNN-backbone hyperparameters used when computing dataset features.
- Return type: