enerzyme.tasks.lightning_utils.EMACallback

class enerzyme.tasks.lightning_utils.EMACallback(use_ema: bool, ema_decay: float | None = None, ema_use_num_updates: bool | None = None)

Bases: Callback

__init__(use_ema: bool, ema_decay: float | None = None, ema_use_num_updates: bool | None = None)
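
This reference does not document how ema_decay and ema_use_num_updates interact. One common convention, used for example by the torch_ema package's ExponentialMovingAverage, warms the decay up with an update counter n so that early averages are not dominated by the initial weights. The sketch below assumes that convention; effective_decay is a hypothetical helper, not part of enerzyme.

```python
from typing import Optional


def effective_decay(ema_decay: float, num_updates: Optional[int] = None) -> float:
    """Hypothetical helper: the decay applied at one EMA step.

    Assumption (torch_ema convention, not confirmed for enerzyme): when an
    update counter is tracked, the decay is capped by (1 + n) / (10 + n),
    so it starts small and approaches ema_decay as n grows.
    """
    if num_updates is None:
        # Counter disabled: always use the configured decay.
        return ema_decay
    return min(ema_decay, (1 + num_updates) / (10 + num_updates))


print(effective_decay(0.99))          # 0.99
print(effective_decay(0.99, 1))       # 2/11, about 0.18
print(effective_decay(0.99, 10_000))  # 0.99
```

With the counter enabled, the first few updates track the live weights closely; after a few thousand steps the cap exceeds ema_decay and the configured value takes over.
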

load_state_dict(state_dict: dict)

Called when loading a checkpoint to reload the callback’s state from state_dict.

Args:

state_dict: the callback state returned by state_dict.

on_test_epoch_end(trainer: Trainer, pl_module: LightningModule)

Called when the test epoch ends.

on_test_epoch_start(trainer: Trainer, pl_module: LightningModule)

Called when the test epoch begins.

on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int)

Called when the train batch ends.

Note:

The value outputs["loss"] here is the loss returned from training_step, normalized with respect to accumulate_grad_batches.
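
Since this callback overrides on_train_batch_end, the average is presumably advanced there after each training batch. A minimal sketch of one such update, using plain Python floats in a dict in place of model tensors and a fixed decay for brevity; ema_update is a hypothetical name, not enerzyme's implementation.

```python
from typing import Dict


def ema_update(shadow: Dict[str, float], params: Dict[str, float], decay: float) -> None:
    """Hypothetical per-batch EMA step, applied in place to ``shadow``.

    shadow <- decay * shadow + (1 - decay) * params
    """
    for name, value in params.items():
        shadow[name] = decay * shadow[name] + (1.0 - decay) * value


shadow = {"w": 0.0}
for _ in range(3):  # three training batches with the live weight fixed at 1.0
    ema_update(shadow, {"w": 1.0}, decay=0.5)
print(shadow["w"])  # 0.875
```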

on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule)

Called when the validation epoch ends.

on_validation_epoch_start(trainer: Trainer, pl_module: LightningModule)

Called when the validation epoch begins.
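
The paired epoch-start and epoch-end hooks for validation (and the matching pair for testing) suggest, though this reference does not state it, that the averaged weights are swapped into the model for evaluation and the live training weights are restored afterwards. A sketch of that swap-and-restore pattern on a plain dict; the class and method names are hypothetical.

```python
from typing import Dict, Optional


class HypotheticalEMASwap:
    """Swap EMA weights in for evaluation, then restore the training weights."""

    def __init__(self, shadow: Dict[str, float]) -> None:
        self.shadow = shadow
        self.backup: Optional[Dict[str, float]] = None

    def on_eval_epoch_start(self, params: Dict[str, float]) -> None:
        # Keep a copy of the live weights, then overwrite them with the averages.
        self.backup = dict(params)
        params.update(self.shadow)

    def on_eval_epoch_end(self, params: Dict[str, float]) -> None:
        # Put the live training weights back so optimization resumes unchanged.
        assert self.backup is not None, "epoch end called without epoch start"
        params.update(self.backup)
        self.backup = None


params = {"w": 1.0}
swap = HypotheticalEMASwap({"w": 0.5})
swap.on_eval_epoch_start(params)
print(params["w"])  # 0.5 during evaluation
swap.on_eval_epoch_end(params)
print(params["w"])  # 1.0 after evaluation
```

Copying the live weights before overwriting them is what makes the swap reversible; with real tensors the backup would need to be a copy rather than a reference to the same storage.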

setup(trainer: Trainer, pl_module: LightningModule, stage: str)

Called when fit, validate, test, predict, or tune begins.

state_dict()

Called when saving a checkpoint to generate the callback’s state_dict.

Returns:

A dictionary containing callback state.
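
Because state_dict and load_state_dict are implemented, the averaged weights presumably survive checkpointing, so resuming training does not restart the average from scratch. A minimal sketch of such a round trip; which fields enerzyme actually stores is not documented, so the keys and the class name below are assumptions.

```python
from typing import Any, Dict, Optional


class HypotheticalEMAState:
    """Toy callback state: decay, optional update counter, averaged weights."""

    def __init__(self, decay: float, use_num_updates: bool) -> None:
        self.decay = decay
        self.num_updates: Optional[int] = 0 if use_num_updates else None
        self.shadow: Dict[str, float] = {}

    def state_dict(self) -> Dict[str, Any]:
        # Copy the shadow dict so later updates don't mutate the saved state.
        return {"decay": self.decay, "num_updates": self.num_updates, "shadow": dict(self.shadow)}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.decay = state_dict["decay"]
        self.num_updates = state_dict["num_updates"]
        self.shadow = dict(state_dict["shadow"])


saved = HypotheticalEMAState(0.99, use_num_updates=True)
saved.shadow, saved.num_updates = {"w": 0.3}, 5
restored = HypotheticalEMAState(0.5, use_num_updates=False)
restored.load_state_dict(saved.state_dict())
print(restored.decay, restored.num_updates, restored.shadow)  # 0.99 5 {'w': 0.3}
```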