enerzyme.tasks.trainer.Trainer
- class enerzyme.tasks.trainer.Trainer(out_dir: str = None, metric_config: Dict = {}, **params)
Bases: object
- __init__(out_dir: str = None, metric_config: Dict = {}, **params) -> None
The trainer class for training and evaluating the model.
Params:
- fit_predict(model: Module, pretrain_path: str | None, train_dataset: Dataset, valid_dataset: Dataset | None, loss_terms: Iterable[Callable], dump_dir: str, transform: Transform, test_dataset: Dataset | None = None, model_rank: int | None = None, max_epoch_per_iter: int = -1, meta_state_dict: Dict = {}, refresh_patience: bool = False, refresh_best_score: bool = False) -> Dict[Literal['y_pred', 'y_truth', 'metric_score'], Any]
Train the model on the training set, validate it on the validation set, and test the model on the test set.
Params:
- model: Module
The model to train.
- pretrain_path: Optional[str]
The path to the pretrained model or to a checkpoint of the model.
- train_dataset: Dataset
The training dataset.
- valid_dataset: Optional[Dataset]
The validation dataset. If not provided, the model will not be validated.
- loss_terms: Iterable[Callable]
The loss terms that make up the training loss.
- dump_dir: str
The directory in which to save the model.
- transform: Transform
The data transform in preprocessing. The inverse transform will be applied to the prediction results when calculating the metrics during validation and testing.
- test_dataset: Optional[Dataset]
The test dataset. If not provided, the model will not be tested.
- model_rank: Optional[int]
Only used in deep ensemble training. The rank of the model.
- max_epoch_per_iter: int
Only used in active learning. The maximum number of epochs per active learning iteration.
- meta_state_dict: Dict
Only used in active learning checkpointing. The meta state dictionary.
- refresh_patience: bool
Whether to refresh the patience when loading the checkpoint.
- refresh_best_score: bool
Whether to refresh the best score when loading the checkpoint.
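The two refresh flags above can be pictured with a small sketch. It assumes, purely for illustration, that a checkpoint stores the best validation score and the number of epochs since the last improvement, and that "refreshing" means resetting those values when the checkpoint is loaded. The `EarlyStopState` class, the `restore_state` function, and the checkpoint keys are hypothetical and are not part of Enerzyme.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class EarlyStopState:
    """Hypothetical early-stopping bookkeeping (not Enerzyme's own class)."""
    best_score: float = float("inf")  # lower is better in this sketch
    wait: int = 0                     # epochs since the last improvement


def restore_state(
    checkpoint: Dict[str, Any],
    refresh_patience: bool = False,
    refresh_best_score: bool = False,
) -> EarlyStopState:
    """Rebuild the early-stopping state from a checkpoint dictionary.

    The keys "best_score" and "wait" are assumptions made for this sketch.
    """
    state = EarlyStopState(
        best_score=checkpoint.get("best_score", float("inf")),
        wait=checkpoint.get("wait", 0),
    )
    if refresh_patience:
        # Start counting epochs without improvement from zero again.
        state.wait = 0
    if refresh_best_score:
        # Forget the stored best score so the next evaluation sets a new one.
        state.best_score = float("inf")
    return state


checkpoint = {"best_score": 0.12, "wait": 7}
resumed = restore_state(checkpoint)
fresh = restore_state(checkpoint, refresh_patience=True, refresh_best_score=True)
print(resumed)
print(fresh)
```

Under these assumptions, leaving both flags at False continues the run exactly where the checkpoint left off, while setting them gives the resumed run a full patience budget and a clean baseline, which may be useful when the data changes between active learning iterations.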
Returns:
The prediction results on the test set. A dictionary with the following keys:
- y_pred
The predicted results.
- y_truth
The true results.
- metric_score
The metric scores computed from the predicted and true results.
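As a rough illustration of how the returned dictionary relates to the `transform` parameter, the sketch below applies an inverse transform to raw predictions before computing a metric, then packs the results under the three documented keys. The `ScaleTransform` class, the `rmse` helper, the `build_result` function, and the use of plain lists are assumptions made for this example; the actual value types are documented only as `Any`, and Enerzyme's own `Transform` interface may differ.

```python
import math
from typing import Any, Callable, Dict, List


class ScaleTransform:
    """Hypothetical transform: divides targets by a constant for training."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def transform(self, values: List[float]) -> List[float]:
        return [v / self.scale for v in values]

    def inverse_transform(self, values: List[float]) -> List[float]:
        return [v * self.scale for v in values]


def rmse(pred: List[float], truth: List[float]) -> float:
    """Root-mean-square error between two equally long sequences."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, truth)) / len(truth))


def build_result(
    raw_pred: List[float],
    y_truth: List[float],
    transform: ScaleTransform,
    metrics: Dict[str, Callable[[List[float], List[float]], float]],
) -> Dict[str, Any]:
    """Map predictions back to original units, then score them."""
    y_pred = transform.inverse_transform(raw_pred)
    metric_score = {name: fn(y_pred, y_truth) for name, fn in metrics.items()}
    return {"y_pred": y_pred, "y_truth": y_truth, "metric_score": metric_score}


transform = ScaleTransform(scale=10.0)
y_truth = [10.0, 20.0, 30.0]   # targets in original units
raw_pred = [1.0, 2.0, 3.5]     # model output in transformed units
result = build_result(raw_pred, y_truth, transform, {"rmse": rmse})
print(sorted(result))
print(result["metric_score"])
```

Scoring in the original units keeps the reported metrics comparable across runs that use different preprocessing.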
- load_state_dict(model: Module, optimizer: Optimizer | None = None, scheduler: LRScheduler | None = None, pretrain_path: str | None = None, ema: ExponentialMovingAverage | None = None, inference: bool = False, strict: bool = True) -> None
- predict(model: Module, dataset: Dataset, loss_terms: Iterable[Callable], dump_dir: str, transform: Transform, epoch: int = 1, load_model: bool = False, model_rank: str | None = None, test_mode: bool = False) -> Dict[Literal['y_pred', 'y_truth', 'val_loss', 'metric_score'], Any]