RandomFeaturesTrainer

class franken.trainers.RandomFeaturesTrainer(train_dataloader, random_features_normalization='leading_eig', log_dir=None, save_every_model=True, device='cuda:0', dtype=torch.float32, save_fmaps=True)

Bases: BaseTrainer

Main class grouping the training and evaluation functionality for franken models.

Parameters:
  • train_dataloader (torch.utils.data.DataLoader) – Dataloader which iterates over the training set.

  • random_features_normalization (Literal["leading_eig"] | None) – How to normalize the covariance matrices formed by the random features. Defaults to “leading_eig”.

  • log_dir (Path | None) – Directory where to save logs and models. If not specified, no logs will be saved. Defaults to None.

  • save_every_model (bool) – A single fit with this class trains one model per combination of solver parameters. This argument controls which of those models are saved: if True, the models for all solver-parameter combinations are saved; otherwise only the ‘best’ model among them (as judged on a validation set) is saved. Defaults to True.

  • device (device | str | int) – PyTorch device on which computations are performed. Defaults to “cuda:0”.

  • dtype (str | torch.dtype) – Data type for solver operations. Random features are always computed in float32 and then converted to float64 if requested. Defaults to torch.float32.

  • save_fmaps (bool) – Whether to save feature maps for the training set. Saving them requires extra memory (linear in the training-set size) but speeds up evaluate() on training data. Defaults to True.
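This page does not spell out the exact normalization. As a minimal sketch, assuming that “leading_eig” means dividing the random-feature covariance matrix by its largest eigenvalue, the operation looks like this in NumPy. The helper normalize_leading_eig is hypothetical and is not franken's implementation.

```python
import numpy as np


def normalize_leading_eig(cov: np.ndarray) -> np.ndarray:
    """Scale a symmetric PSD covariance matrix so its largest eigenvalue is 1.

    Hypothetical helper illustrating one plausible reading of
    random_features_normalization="leading_eig"; not franken's code.
    """
    # eigvalsh returns the eigenvalues of a symmetric matrix in ascending order
    leading = np.linalg.eigvalsh(cov)[-1]
    if leading <= 0:
        raise ValueError("covariance matrix has no positive eigenvalue")
    return cov / leading


# Toy random-feature matrix: 100 samples, 8 features each
rng = np.random.default_rng(0)
phi = rng.standard_normal((100, 8))
cov = phi.T @ phi / phi.shape[0]

cov_normalized = normalize_leading_eig(cov)
# The largest eigenvalue of the rescaled matrix is now 1 (up to rounding)
print(np.linalg.eigvalsh(cov_normalized)[-1])
```

Under this assumption the rescaled matrix has spectral norm 1, so a fixed l2_penalty acts with a comparable relative strength regardless of the raw scale of the random features.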

fit(model, solver_params)

Fit a given franken model on the training set.

Parameters:
  • model (FrankenPotential) – The model which defines GNN and random features.

  • solver_params (dict) – Parameters for the solver which actually performs the fit. Each parameter may be given a list of values, and a separate fit is performed for every combination of them. For example, passing {"l2_penalty": [1e-6, 1e-4], "force_weight": [0.5]} results in two different models, one with l2_penalty=1e-6, force_weight=0.5 and one with l2_penalty=1e-4, force_weight=0.5. Specifying solver parameters this way makes it easy to perform a grid search.
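The expansion described above is a Cartesian product over the value lists. A minimal sketch using itertools.product follows; the helper expand_solver_params is hypothetical, mirrors only the documented behavior, and assumes every value in the dict is a list.

```python
import itertools
from typing import Any


def expand_solver_params(solver_params: dict[str, list[Any]]) -> list[dict[str, Any]]:
    """Expand a dict of value lists into one dict per grid point.

    Hypothetical helper mirroring the documented behavior of fit();
    not franken's code. Assumes each value is a list of candidates.
    """
    keys = list(solver_params)
    # itertools.product yields every combination, one value per key
    return [
        dict(zip(keys, values))
        for values in itertools.product(*(solver_params[k] for k in keys))
    ]


grid = expand_solver_params({"l2_penalty": [1e-6, 1e-4], "force_weight": [0.5]})
for point in grid:
    print(point)
# {'l2_penalty': 1e-06, 'force_weight': 0.5}
# {'l2_penalty': 0.0001, 'force_weight': 0.5}
```

The number of fitted models is the product of the list lengths, so adding a third value to either list would grow the grid multiplicatively.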

Returns:

The fitting logs, together with the learned weights.

Return type:

tuple[LogCollection, torch.Tensor]

Note

More information about the available solver parameters can be found under the solve() method.