enerzyme.tasks.lightning_utils.CollectOutputCallback

class enerzyme.tasks.lightning_utils.CollectOutputCallback

Bases: Callback

__init__()
on_test_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Dict[Literal['raw_output', 'loss'], Dict[str, Tensor] | float], batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]], batch_idx: int, dataloader_idx: int = 0) -> None

Called when the test batch ends.
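Both batch-end hooks (test and validation) receive `outputs` as a dict whose `'raw_output'` entry maps output names to tensors and whose `'loss'` entry is a float. The sketch below shows one way a batch-end hook could buffer those raw outputs per key. It is a hypothetical illustration, not enerzyme's implementation: `buffer_batch_output` and the `'energy'` key are invented names, and plain lists of floats stand in for `torch.Tensor` so it runs without PyTorch or Lightning.

```python
from typing import Dict, List, Literal, Union

# Hypothetical stand-in: a list of floats plays the role of a torch Tensor.
FakeTensor = List[float]
BatchOutputs = Dict[Literal["raw_output", "loss"], Union[Dict[str, FakeTensor], float]]


def buffer_batch_output(buffer: Dict[str, List[FakeTensor]], outputs: BatchOutputs) -> None:
    """Append each named value from outputs["raw_output"] to a per-key list.

    Illustrative sketch only; not the enerzyme implementation.
    """
    raw = outputs["raw_output"]
    if not isinstance(raw, dict):
        raise TypeError("expected outputs['raw_output'] to be a dict of named tensors")
    for key, value in raw.items():
        # One entry per batch, so batches can be concatenated at epoch end.
        buffer.setdefault(key, []).append(value)


# Simulate two batch-end calls with hypothetical outputs.
buffer: Dict[str, List[FakeTensor]] = {}
buffer_batch_output(buffer, {"raw_output": {"energy": [1.0, 2.0]}, "loss": 0.5})
buffer_batch_output(buffer, {"raw_output": {"energy": [3.0]}, "loss": 0.4})
print(buffer)  # {'energy': [[1.0, 2.0], [3.0]]}
```

Keeping one list entry per batch (rather than merging immediately) defers the concatenation cost to a single step at the end of the epoch.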

on_test_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None

Called when the test epoch ends.

on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Dict[Literal['raw_output', 'loss'], Dict[str, Tensor] | float], batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]], batch_idx: int, dataloader_idx: int = 0) -> None

Called when the validation batch ends.

on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) -> None

Called when the validation epoch ends.
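An epoch-end hook in an output-collecting callback would typically merge the per-batch values gathered during the epoch and reset its buffer. The following is a minimal sketch under that assumption, not enerzyme's code: `flush_epoch_buffer` and the `'energy'`/`'forces'` keys are hypothetical, and lists of floats stand in for tensors. With real tensors, `torch.cat(chunks)` would likely replace the list flattening.

```python
from typing import Dict, List

# Hypothetical stand-in: a list of floats plays the role of a torch Tensor.
FakeTensor = List[float]


def flush_epoch_buffer(buffer: Dict[str, List[FakeTensor]]) -> Dict[str, FakeTensor]:
    """Concatenate the buffered per-batch values for each key, then clear the buffer.

    Illustrative sketch only; not the enerzyme implementation.
    """
    merged = {key: [x for chunk in chunks for x in chunk] for key, chunks in buffer.items()}
    buffer.clear()  # reset so the next epoch starts from an empty buffer
    return merged


# Hypothetical buffer as it might look after two batches.
buffer = {"energy": [[1.0, 2.0], [3.0]], "forces": [[0.1], [0.2, 0.3]]}
merged = flush_epoch_buffer(buffer)
print(merged)  # {'energy': [1.0, 2.0, 3.0], 'forces': [0.1, 0.2, 0.3]}
print(buffer)  # {}
```

Clearing the buffer inside the epoch-end step keeps validation and test epochs from leaking outputs into each other when the same callback instance is reused.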