BaseTrainer

class franken.trainers.BaseTrainer(train_dataloader, log_dir=None, save_every_model=True, device='cpu', dtype=torch.float32)

Bases: ABC

Abstract base class for trainers. Subclasses must implement the fit() and evaluate() methods.
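The fit()/evaluate() contract can be illustrated with a minimal, dependency-free sketch. It does not import franken. SketchTrainer is a hypothetical stand-in that only mirrors the two documented abstract methods, and MeanTrainer is a hypothetical toy subclass. Plain lists stand in for dataloaders and torch.Tensor weights, and a dict stands in for LogCollection. Because the methods are declared with abc.abstractmethod, Python refuses to instantiate any class that leaves one of them unimplemented.

```python
from abc import ABC, abstractmethod


class SketchTrainer(ABC):
    """Hypothetical stand-in mirroring the abstract interface documented for BaseTrainer."""

    @abstractmethod
    def fit(self, model, solver_params):
        """Return a (logs, weights) pair."""

    @abstractmethod
    def evaluate(self, model, dataloader, log_collection, all_weights, metrics):
        """Return log_collection, updated with the requested metric values."""


class MeanTrainer(SketchTrainer):
    """Hypothetical toy trainer whose 'model' is a constant predictor (the mean target)."""

    def __init__(self, train_data):
        self.train_data = train_data  # list of (x, y) pairs standing in for a dataloader

    def fit(self, model, solver_params):
        # `model` is unused in this toy; a real trainer would compute features with it.
        targets = [y for _, y in self.train_data]
        weights = [sum(targets) / len(targets)]
        logs = {"solver_params": dict(solver_params)}
        return logs, weights

    def evaluate(self, model, dataloader, log_collection, all_weights, metrics):
        if "mae" in metrics:
            errors = [abs(y - all_weights[0]) for _, y in dataloader]
            log_collection["mae"] = sum(errors) / len(errors)
        return log_collection  # the same object that was passed in


if __name__ == "__main__":
    trainer = MeanTrainer([(0, 1.0), (1, 3.0)])
    logs, weights = trainer.fit(model=None, solver_params={})
    logs = trainer.evaluate(None, [(0, 2.5), (1, 1.0)], logs, weights, metrics=["mae"])
    print(logs)
```

The usage at the bottom follows the intended order of operations: fit() produces logs and weights, which are then passed unchanged to evaluate().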

abstract evaluate(model, dataloader, log_collection, all_weights, metrics)

Evaluate a fitted model by computing metrics on a validation dataset.

Parameters:
  • model (FrankenPotential) – The model, which defines the GNN backbone and the random features.

  • dataloader (torch.utils.data.DataLoader) – The model is run on each configuration in the dataloader, and the metrics are averaged over all configurations.

  • log_collection (LogCollection) – Log object as returned by the fit() method. Metric values are added to these logs, and the same object is returned by this method.

  • all_weights (torch.Tensor) – The weights returned by the fit() method.

  • metrics (list[str]) – List of metrics which should be computed.

Returns:

Logs which contain all parameters related to the fitting, as well as timings and metrics.

Return type:

logs (LogCollection)
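The averaging step performed by evaluate() can be sketched as follows. This is an illustration under stated assumptions, not franken's implementation: average_metrics and METRIC_FNS are hypothetical, the metric names "energy_mae" and "energy_per_atom_mae" are invented for the example, each configuration is a plain dict, and a dict stands in for LogCollection. As the documentation describes for evaluate(), the helper writes the averaged values into the log object it receives and returns that same object.

```python
from statistics import fmean

# Hypothetical per-configuration error functions: (predicted energy, configuration) -> error.
METRIC_FNS = {
    "energy_mae": lambda pred, conf: abs(pred - conf["energy"]),
    "energy_per_atom_mae": lambda pred, conf: abs(pred - conf["energy"]) / conf["n_atoms"],
}


def average_metrics(predict, dataloader, metrics, log_collection):
    """Run `predict` on every configuration and store the mean of each requested metric."""
    unknown = [name for name in metrics if name not in METRIC_FNS]
    if unknown:
        raise ValueError(f"unknown metrics: {unknown}")
    per_metric = {name: [] for name in metrics}
    for conf in dataloader:
        pred = predict(conf)  # one prediction per configuration
        for name in metrics:
            per_metric[name].append(METRIC_FNS[name](pred, conf))
    for name, values in per_metric.items():
        log_collection[name] = fmean(values)  # average over configurations
    return log_collection  # same object that was passed in
```

Averaging per-configuration errors weights every configuration equally, regardless of its number of atoms; the documentation above does not say which weighting franken uses.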

abstract fit(model, solver_params)

Fit a given franken model on the training set.

Parameters:
  • model (FrankenPotential) – The model, which defines the GNN backbone and the random features.

  • solver_params (dict) – Parameters for the solver which actually performs the fit.

Returns:

  • Logs which contain all parameters related to the fitting, as well as timings.

  • Weights which were learned during the fit.

Return type:

tuple[LogCollection, torch.Tensor]
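The solver behind fit() is not specified here. As one possible illustration, and explicitly an assumption rather than franken's own solver, the sketch below fits linear weights on precomputed features with closed-form ridge regression, w = (XᵀX + λI)⁻¹Xᵀy, and returns a (logs, weights) pair shaped like the one fit() documents. The function name ridge_fit and the "l2_penalty" key in solver_params are hypothetical, NumPy arrays stand in for torch.Tensor, and a dict stands in for LogCollection.

```python
import time

import numpy as np


def ridge_fit(features, targets, solver_params):
    """Hypothetical solver: closed-form ridge regression on precomputed features.

    features: array of shape (n_samples, n_features)
    targets:  array of shape (n_samples,)
    """
    start = time.perf_counter()
    l2 = float(solver_params.get("l2_penalty", 0.0))  # hypothetical parameter name
    n_features = features.shape[1]
    # Solve (X^T X + l2 * I) w = X^T y directly instead of forming an explicit inverse.
    gram = features.T @ features + l2 * np.eye(n_features)
    weights = np.linalg.solve(gram, features.T @ targets)
    logs = {
        "solver_params": dict(solver_params),  # record the parameters used for the fit
        "n_samples": features.shape[0],
        "fit_time_s": time.perf_counter() - start,  # timing, as the fit() logs describe
    }
    return logs, weights
```

Using np.linalg.solve rather than np.linalg.inv is the usual choice here: it is cheaper and numerically more stable for the same result.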

get_statistics(model)

Compute statistics of the training dataset using the provided model.

Parameters:

  • model (FrankenPotential) – Franken model whose attached GNN is used to compute the features of the atomic configurations.

Returns:

A tuple of a franken.rf.scaler.Statistics object holding the dataset statistics and a dictionary of the GNN-backbone hyperparameters used to compute the dataset features.

Return type:

Tuple[Statistics, dict]
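The kind of computation get_statistics() describes (run a feature extractor over every configuration, then summarize the features) can be sketched as follows. Everything here is hypothetical: feature_statistics is not a franken function, per-dimension mean and standard deviation over all atoms are assumed because the fields of franken.rf.scaler.Statistics are not documented here, a dict replaces the Statistics object, and feature_fn stands in for the GNN backbone.

```python
import numpy as np


def feature_statistics(feature_fn, configurations, backbone_hparams):
    """Hypothetical sketch: per-dimension feature statistics over a dataset.

    feature_fn maps one configuration to an array of shape (n_atoms, n_features).
    """
    # Stack per-atom features from every configuration into one (total_atoms, n_features) array.
    feats = np.concatenate([np.atleast_2d(feature_fn(conf)) for conf in configurations], axis=0)
    stats = {
        "mean": feats.mean(axis=0),
        "std": feats.std(axis=0),  # population standard deviation (ddof=0)
        "n_samples": feats.shape[0],
    }
    # Return a copy of the hyperparameters alongside the statistics, mirroring (Statistics, dict).
    return stats, dict(backbone_hparams)
```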