enerzyme.tasks.trainer.Trainer

class enerzyme.tasks.trainer.Trainer(out_dir: str = None, metric_config: Dict = {}, **params)

Bases: object

__init__(out_dir: str = None, metric_config: Dict = {}, **params) → None

The trainer class for training and evaluating the model.

Params:

out_dir: str

The directory in which to save the model.

metric_config: dict

The configuration for the Metrics class.

**params: dict

Additional keyword arguments that configure the trainer.

decorate_batch_input(batch)
decorate_batch_output(output, features, targets)
fit_predict(model: Module, pretrain_path: str | None, train_dataset: Dataset, valid_dataset: Dataset | None, loss_terms: Iterable[Callable], dump_dir: str, transform: Transform, test_dataset: Dataset | None = None, model_rank: int | None = None, max_epoch_per_iter: int = -1, meta_state_dict: Dict = {}, refresh_patience: bool = False, refresh_best_score: bool = False) → Dict[Literal['y_pred', 'y_truth', 'metric_score'], Any]

Train the model on the training set, validate it on the validation set, and test it on the test set. Validation and testing are skipped when the corresponding dataset is not provided.
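The control flow described above can be sketched in plain Python. This is a minimal, simplified illustration rather than enerzyme's implementation: `fit_predict_sketch`, `train_one_epoch` and `evaluate` are hypothetical names, datasets are plain sequences, and checkpointing, early stopping and loss handling are left out.

```python
from typing import Any, Callable, Dict, Optional, Sequence


def fit_predict_sketch(
    train_dataset: Sequence[Any],
    valid_dataset: Optional[Sequence[Any]],
    test_dataset: Optional[Sequence[Any]],
    train_one_epoch: Callable[[Sequence[Any]], float],  # hypothetical: returns the epoch's training loss
    evaluate: Callable[[Sequence[Any]], float],  # hypothetical: returns a score on a dataset
    n_epochs: int = 3,
) -> Dict[str, Any]:
    history: Dict[str, Any] = {"train_loss": [], "valid_score": [], "test_score": None}
    for _ in range(n_epochs):
        history["train_loss"].append(train_one_epoch(train_dataset))
        # Validation is skipped entirely when no validation set is given.
        if valid_dataset is not None:
            history["valid_score"].append(evaluate(valid_dataset))
    # Testing runs after training, and only if a test set is given.
    if test_dataset is not None:
        history["test_score"] = evaluate(test_dataset)
    return history


log = fit_predict_sketch(
    [1, 2, 3], None, [4, 5],
    train_one_epoch=lambda d: float(len(d)),
    evaluate=lambda d: float(sum(d)),
)
print(log)  # {'train_loss': [3.0, 3.0, 3.0], 'valid_score': [], 'test_score': 9.0}
```

Passing `None` for a dataset simply removes that phase, which mirrors the optional `valid_dataset` and `test_dataset` parameters documented below.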

Params:

model: Module

The model to train.

pretrain_path: Optional[str]

The path to a pretrained model or to a checkpoint of the model.

train_dataset: Dataset

The training dataset.

valid_dataset: Optional[Dataset]

The validation dataset. If not provided, the model will not be validated.

loss_terms: Iterable[Callable]

The loss functions to use, one callable per term of the (possibly multi-term) loss.

dump_dir: str

The directory in which to save the model.

transform: Transform

The data transform used in preprocessing. Its inverse is applied to the predictions when the metrics are computed during validation and testing.

test_dataset: Optional[Dataset]

The test dataset. If not provided, the model will not be tested.

model_rank: Optional[int]

Only used in deep ensemble training. The rank of the model within the ensemble.

max_epoch_per_iter: int

Only used in active learning. The maximum number of epochs per active learning iteration.

meta_state_dict: Dict

Only used in active learning checkpointing. The meta state dictionary.

refresh_patience: bool

Whether to reset the patience counter when loading the checkpoint.

refresh_best_score: bool

Whether to reset the best score when loading the checkpoint.
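The resume-related parameters above can be illustrated with a small sketch of early-stopping bookkeeping. Everything in it is an assumption made for illustration: `EarlyStoppingState`, `resume_state` and `epochs_this_iteration` are hypothetical, a lower score is taken to be better, and a non-positive `max_epoch_per_iter` (the default is -1) is assumed to mean "no per-iteration cap".

```python
import math
from dataclasses import dataclass, replace


@dataclass
class EarlyStoppingState:
    # Hypothetical checkpointed state; assumes lower scores are better.
    best_score: float  # best validation score seen so far
    wait: int  # epochs since the last improvement, compared against the patience


def resume_state(
    saved: EarlyStoppingState, refresh_patience: bool, refresh_best_score: bool
) -> EarlyStoppingState:
    state = replace(saved)  # copy, so the loaded checkpoint object is not mutated
    if refresh_patience:
        state.wait = 0  # restart the patience count from zero
    if refresh_best_score:
        state.best_score = math.inf  # any new validation score counts as an improvement
    return state


def epochs_this_iteration(remaining_epochs: int, max_epoch_per_iter: int) -> int:
    # Assumption: a non-positive cap (default -1) means "no per-iteration cap".
    if max_epoch_per_iter <= 0:
        return remaining_epochs
    return min(remaining_epochs, max_epoch_per_iter)


saved = EarlyStoppingState(best_score=0.12, wait=7)
resumed = resume_state(saved, refresh_patience=True, refresh_best_score=False)
print(resumed)  # EarlyStoppingState(best_score=0.12, wait=0)
```

Resetting only the patience counter lets a resumed run, for example a new active-learning iteration, train for a fresh patience window while still comparing against the best score reached so far.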

Returns:

The prediction results on the test set. A dictionary with the following keys:

y_pred

The predictions.

y_truth

The ground-truth values.

metric_score

The metric scores computed from the predictions and the ground-truth values.
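To illustrate how the returned dictionary relates to the inverse transform, the sketch below maps normalized predictions and targets back to the original units before scoring them. It assumes a hypothetical standardization transform with an `inverse` method and a mean absolute error metric. Neither is enerzyme's actual `Transform` or metric API, and the name-to-score layout of `metric_score` is likewise an assumption.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class Standardize:
    # Hypothetical transform z = (y - mean) / std; not enerzyme's Transform API.
    mean: float
    std: float

    def inverse(self, z: List[float]) -> List[float]:
        # Map normalized values back to the original units.
        return [v * self.std + self.mean for v in z]


def mean_absolute_error(y_pred: List[float], y_truth: List[float]) -> float:
    return sum(abs(p - t) for p, t in zip(y_pred, y_truth)) / len(y_truth)


def build_result(z_pred: List[float], z_truth: List[float], transform: Standardize) -> Dict[str, Any]:
    # Inverse-transform first, so the metric is reported in the original units.
    y_pred = transform.inverse(z_pred)
    y_truth = transform.inverse(z_truth)
    return {
        "y_pred": y_pred,
        "y_truth": y_truth,
        # Assumed structure: metric name -> score.
        "metric_score": {"mae": mean_absolute_error(y_pred, y_truth)},
    }


result = build_result([0.0, 1.0], [0.5, 1.0], Standardize(mean=10.0, std=2.0))
print(result["y_pred"], result["y_truth"], result["metric_score"])
# [10.0, 12.0] [11.0, 12.0] {'mae': 0.5}
```

Scoring in the original units makes the reported error directly comparable to the raw targets. Computed in the normalized space, the same error would be divided by `std` (0.25 instead of 0.5 here).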

load_state_dict(model: Module, optimizer: Optimizer | None = None, scheduler: LRScheduler | None = None, pretrain_path: str | None = None, ema: ExponentialMovingAverage | None = None, inference: bool = False, strict: bool = True) → None
predict(model: Module, dataset: Dataset, loss_terms: Iterable[Callable], dump_dir: str, transform: Transform, epoch: int = 1, load_model: bool = False, model_rank: str | None = None, test_mode: bool = False) → Dict[Literal['y_pred', 'y_truth', 'val_loss', 'metric_score'], Any]
save_state_dict(model: Module, optimizer: Optimizer, scheduler: LRScheduler, dump_dir: str, ema: ExponentialMovingAverage | None = None, suffix='last', model_rank=None, epoch: int | None = None, best_score: float | None = None, best_epoch: int | None = None)
to_device(batch)