enerzyme.tasks.lightning_utils.EMACallback
- class enerzyme.tasks.lightning_utils.EMACallback(use_ema: bool, ema_decay: float | None = None, ema_use_num_updates: bool | None = None)
Bases: Callback
- __init__(use_ema: bool, ema_decay: float | None = None, ema_use_num_updates: bool | None = None)
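The constructor exposes `ema_decay` and `ema_use_num_updates`, but this page does not describe the update rule. The sketch below is a minimal, stdlib-only illustration, assuming the callback follows the common convention (used, for example, by the torch_ema package) of updating a shadow copy as shadow = decay * shadow + (1 - decay) * param and, when the update counter is enabled, capping the decay at (1 + n) / (10 + n) so early updates track the raw weights closely. `ScalarEMA` is hypothetical and uses plain floats in place of tensors.

```python
from __future__ import annotations


class ScalarEMA:
    """Hypothetical EMA over a dict of scalar parameters (stand-in for tensors)."""

    def __init__(self, params: dict[str, float], decay: float, use_num_updates: bool = True):
        self.decay = decay
        # None disables the warm-up counter entirely.
        self.num_updates: int | None = 0 if use_num_updates else None
        # The shadow (EMA) copy starts equal to the raw weights.
        self.shadow = dict(params)

    def update(self, params: dict[str, float]) -> float:
        """Apply one EMA step in place and return the decay that was used."""
        decay = self.decay
        if self.num_updates is not None:
            # Assumed warm-up rule (torch_ema convention): smaller decay early on.
            self.num_updates += 1
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        for name, value in params.items():
            self.shadow[name] = decay * self.shadow[name] + (1.0 - decay) * value
        return decay
```

With the warm-up enabled and `decay=0.99`, the first update actually uses 2/11, so the shadow weight moves most of the way toward the new value at once; without it, the shadow moves only 1% per step.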
- load_state_dict(state_dict: dict)
Called when loading a checkpoint; implement to reload the callback's state from its state_dict.
- Args:
state_dict: the callback state returned by state_dict().
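The page does not list which keys EMACallback stores in its checkpoint state. As a hedged sketch only, an EMA callback would plausibly need the decay, the update counter, and the shadow weights to resume exactly; `EMAStateSketch` and its key names are hypothetical. The important property is that `load_state_dict` restores exactly what `state_dict` produced, and that the saved dict is a copy rather than a live reference.

```python
from __future__ import annotations

import copy


class EMAStateSketch:
    """Hypothetical stand-in for the state an EMA callback might checkpoint."""

    def __init__(self, decay: float):
        self.decay = decay
        self.num_updates = 0
        self.shadow: dict[str, float] = {}

    def state_dict(self) -> dict:
        # Deep-copy so later EMA updates do not mutate an already-saved checkpoint.
        return copy.deepcopy(
            {"decay": self.decay, "num_updates": self.num_updates, "shadow": self.shadow}
        )

    def load_state_dict(self, state_dict: dict) -> None:
        # Restore exactly the fields that state_dict() wrote.
        self.decay = state_dict["decay"]
        self.num_updates = state_dict["num_updates"]
        self.shadow = copy.deepcopy(state_dict["shadow"])
```

Restoring the update counter matters if the decay warm-up depends on it; otherwise a resumed run would restart the warm-up from scratch.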
- on_test_epoch_end(trainer: Trainer, pl_module: LightningModule)
Called when the test epoch ends.
- on_test_epoch_start(trainer: Trainer, pl_module: LightningModule)
Called when the test epoch begins.
- on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int)
Called when the train batch ends.
- Note:
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
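As a worked example of the note above, assuming "normalized" means divided by `accumulate_grad_batches` (the usual scaling for gradient accumulation, so that accumulated gradients average rather than sum): if training_step returns 2.0 and `accumulate_grad_batches` is 4, the hook sees 0.5. The helper names below are hypothetical.

```python
def normalized_loss(step_loss: float, accumulate_grad_batches: int) -> float:
    """Value expected in outputs["loss"] (assumed: loss / accumulate_grad_batches)."""
    if accumulate_grad_batches < 1:
        raise ValueError("accumulate_grad_batches must be >= 1")
    return step_loss / accumulate_grad_batches


def raw_loss(outputs_loss: float, accumulate_grad_batches: int) -> float:
    """Recover the value training_step returned from the normalized outputs["loss"]."""
    return outputs_loss * accumulate_grad_batches


if __name__ == "__main__":
    # training_step returned 2.0 while accumulating gradients over 4 batches
    print(normalized_loss(2.0, 4))  # 0.5
    print(raw_loss(0.5, 4))         # 2.0
```

Code that logs the loss from this hook should multiply it back by `accumulate_grad_batches` if it wants the value training_step actually returned.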
- on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule)
Called when the val epoch ends.
- on_validation_epoch_start(trainer: Trainer, pl_module: LightningModule)
Called when the val epoch begins.
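The validation and test hooks above come in start/end pairs, which suggests (the page does not say so explicitly) the usual EMA evaluation pattern: back up the raw weights and load the EMA weights when the epoch starts, then restore the raw weights when it ends so training resumes from the optimizer's own weights. The sketch below shows that cycle with plain dictionaries; `SwapSketch` and its method names are hypothetical and not part of enerzyme.

```python
from __future__ import annotations


class SwapSketch:
    """Hypothetical store / copy-in / restore cycle around an evaluation epoch."""

    def __init__(self, params: dict[str, float], shadow: dict[str, float]):
        self.params = params  # live weights that the optimizer trains
        self.shadow = shadow  # EMA weights used for evaluation
        self._backup: dict[str, float] | None = None

    def on_eval_epoch_start(self) -> None:
        # Save the raw weights, then evaluate with the EMA weights.
        self._backup = dict(self.params)
        self.params.update(self.shadow)

    def on_eval_epoch_end(self) -> None:
        # Put the raw weights back; a second call without a matching start is a no-op.
        if self._backup is not None:
            self.params.update(self._backup)
            self._backup = None
```

Keeping the backup separate from the shadow copy means the evaluation swap never overwrites the EMA weights themselves.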