FeatureScaler

class franken.rf.scaler.FeatureScaler(input_dim, statistics, scale_by_Z, num_species)

Bases: Module

Mean and standard deviation scaler for GNN features.

Can be initialized from a franken.rf.scaler.Statistics instance; when called, it scales GNN features. Supports both global and per-species normalization.
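Assuming the usual definition of standardization, the two modes can be written as below. Here x_{i,j} is feature j of atom i, Z_i is that atom's atomic number, and mu and sigma are the mean and standard deviation taken from the Statistics instance: computed over all atoms in the global case, or only over atoms of species Z in the per-species case. The notation is ours rather than the library's, and any small epsilon an implementation may add to the denominator is not shown.

```latex
\tilde{x}_{i,j} = \frac{x_{i,j} - \mu_j}{\sigma_j}
  \quad \text{(global)}, \qquad
\tilde{x}_{i,j} = \frac{x_{i,j} - \mu_{Z_i,j}}{\sigma_{Z_i,j}}
  \quad \text{(per-species)}
```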

forward(descriptors, atomic_numbers=None)

Scale the given features to have zero mean and unit standard deviation.

Parameters:
  • descriptors (torch.Tensor) – GNN features

  • atomic_numbers (torch.Tensor | None) – Atomic numbers for each atom. This can be left as None unless the feature scaler has been configured to perform per-species normalization. Defaults to None.

Returns:

Normalized GNN features

Return type:

torch.Tensor
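The forward behavior described above, including the requirement that atomic_numbers be supplied when per-species normalization is configured, can be sketched without PyTorch as follows. ToyFeatureScaler, its global_stats/species_stats arguments and the eps guard are hypothetical illustrations, not franken's API; the real module operates on torch.Tensor batches and may store its statistics differently.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# (mean, std), one entry per feature dimension
Stats = Tuple[List[float], List[float]]


class ToyFeatureScaler:
    """Hypothetical, framework-free stand-in for FeatureScaler.forward.

    Holds either one (mean, std) pair for global scaling, or a dict
    mapping atomic number Z -> (mean, std) for per-species scaling.
    """

    def __init__(self, global_stats: Optional[Stats] = None,
                 species_stats: Optional[Dict[int, Stats]] = None,
                 eps: float = 1e-8):
        if (global_stats is None) == (species_stats is None):
            raise ValueError("pass exactly one of global_stats / species_stats")
        self.global_stats = global_stats
        self.species_stats = species_stats
        self.eps = eps  # assumed guard against division by a zero std

    def forward(self, descriptors: Sequence[Sequence[float]],
                atomic_numbers: Optional[Sequence[int]] = None) -> List[List[float]]:
        # Per-species mode cannot pick statistics without knowing each atom's Z.
        if self.species_stats is not None and atomic_numbers is None:
            raise ValueError("atomic_numbers is required for per-species scaling")
        scaled = []
        for i, row in enumerate(descriptors):
            if self.species_stats is not None:
                mean, std = self.species_stats[atomic_numbers[i]]
            else:
                mean, std = self.global_stats
            # Subtract the mean and divide by the std, dimension by dimension.
            scaled.append([(x - m) / (s + self.eps) for x, m, s in zip(row, mean, std)])
        return scaled


# Usage: hydrogen (Z=1) and oxygen (Z=8) rows are scaled with different statistics.
scaler = ToyFeatureScaler(species_stats={1: ([0.0], [1.0]), 8: ([10.0], [2.0])})
print(scaler.forward([[1.0], [14.0]], atomic_numbers=[1, 8]))  # roughly [[1.0], [2.0]]
```

In global mode atomic_numbers is simply ignored, which mirrors the note that it can be left as None unless per-species normalization is enabled.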

set_from_statistics(statistics)

Set the mean and standard deviation statistics for scaling.

Parameters:

statistics (Statistics) – Instance of franken.rf.scaler.Statistics from which the feature mean and standard deviation can be fetched.
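One way a Statistics-like object could gather the mean and standard deviation that set_from_statistics consumes is Welford's online algorithm, which updates the moments batch by batch without keeping every feature vector in memory. This is an assumption for illustration: ToyStatistics and ToyScaler are hypothetical, and whether franken.rf.scaler.Statistics uses Welford's method, or the population rather than the sample standard deviation, is not stated here.

```python
import math
from typing import List, Sequence


class ToyStatistics:
    """Hypothetical accumulator of per-dimension mean/std (Welford's algorithm)."""

    def __init__(self, dim: int):
        self.count = 0
        self._mean = [0.0] * dim
        self._m2 = [0.0] * dim  # running sum of squared deviations from the mean

    def update(self, rows: Sequence[Sequence[float]]) -> None:
        for row in rows:
            self.count += 1
            for j, x in enumerate(row):
                delta = x - self._mean[j]
                self._mean[j] += delta / self.count
                self._m2[j] += delta * (x - self._mean[j])

    @property
    def mean(self) -> List[float]:
        return list(self._mean)

    @property
    def std(self) -> List[float]:
        # Population standard deviation; assumes at least one row was seen.
        return [math.sqrt(m2 / self.count) for m2 in self._m2]


class ToyScaler:
    """Hypothetical scaler exposing a set_from_statistics-style method."""

    def __init__(self, dim: int):
        # Identity scaling until statistics are provided.
        self.mean = [0.0] * dim
        self.std = [1.0] * dim

    def set_from_statistics(self, statistics: ToyStatistics) -> None:
        # Copy the accumulated moments into the scaler.
        self.mean = statistics.mean
        self.std = statistics.std


stats = ToyStatistics(dim=2)
stats.update([[1.0, 10.0]])  # batches can arrive one at a time
stats.update([[3.0, 30.0]])
scaler = ToyScaler(dim=2)
scaler.set_from_statistics(stats)
print(scaler.mean, scaler.std)  # [2.0, 20.0] [1.0, 10.0]
```

Accumulating in a single pass means the statistics can be built by streaming a dataset through in batches, with set_from_statistics called once at the end.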