Source code for enerzyme.tasks.lightning_utils

from collections import defaultdict
from typing import Any, Optional, Iterable, Callable, Dict, Literal, Union, Tuple, List
import numpy as np
import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch_ema import ExponentialMovingAverage
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from .monitor import Monitor
from ..data.transform import Transform
from .metrics import Metrics
from .batch import _decorate_batch_output


class CollectOutputCallback(L.Callback):
    """Collect per-batch predictions during validation/test and log epoch results.

    After every validation or test batch the decorated prediction/target pair
    and the batch loss are appended to ``pl_module.validation_step_outputs`` or
    ``pl_module.test_step_outputs``. At the end of the epoch those lists are
    merged, optionally inverse-transformed, scored with ``pl_module.metrics``
    and logged through ``pl_module.log`` / ``pl_module.log_dict``.
    """

    def __init__(self):
        super().__init__()

    def _collect_output(
        self,
        outputs: Dict[Literal["raw_output", "loss"], Union[Dict[str, Tensor], float]],
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
    ):
        """Turn one step's output into a ``{"y_pred", "y_truth", "loss"}`` record.

        Args:
            outputs: Step result holding the model's ``"raw_output"`` and ``"loss"``.
            batch: ``(net_input, net_target)`` pair that produced ``outputs``.

        Returns:
            Dict with the values returned by ``_decorate_batch_output`` under
            ``"y_pred"`` / ``"y_truth"`` and the untouched ``outputs["loss"]``.
        """
        net_input, net_target = batch
        y_pred, y_truth = _decorate_batch_output(outputs["raw_output"], net_input, net_target)
        return {"y_pred": y_pred, "y_truth": y_truth, "loss": outputs["loss"]}

    def _reduce_outputs(
        self,
        step_outputs: List[Dict[str, Any]],
        transform: Transform,
        metrics: Metrics,
    ):
        """Merge all collected step records of an epoch and score them.

        Args:
            step_outputs: Records produced by ``_collect_output``; the list is
                emptied in place before returning so the next epoch starts clean.
            transform: Optional transform whose ``inverse_transform`` is called
                on the merged predictions and targets (skipped when ``None``).
            metrics: Object whose ``cal_metric(y_truths, y_preds)`` yields the
                metric scores.

        Returns:
            ``(y_preds, y_truths, total_loss, metric_score)`` where
            ``total_loss`` is the unweighted mean of the per-batch losses.
        """
        y_preds = defaultdict(list)
        y_truths = defaultdict(list)
        # Each stored loss must expose ``.item()`` (e.g. a 0-dim tensor).
        total_loss = np.mean([output["loss"].item() for output in step_outputs])
        for output in step_outputs:
            for k, v in output["y_pred"].items():
                y_preds[k].extend(v)
            for k, v in output["y_truth"].items():
                y_truths[k].extend(v)
        if transform is not None:
            # Return values are ignored, so the transform is presumably applied
            # in place -- TODO confirm against Transform.inverse_transform.
            transform.inverse_transform(y_preds)
            transform.inverse_transform(y_truths)
        metric_score = metrics.cal_metric(y_truths, y_preds)
        step_outputs.clear()
        return y_preds, y_truths, total_loss, metric_score

    def on_validation_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        outputs: Dict[Literal["raw_output", "loss"], Union[Dict[str, Tensor], float]],
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Store the decorated output of one validation batch on ``pl_module``."""
        pl_module.validation_step_outputs.append(self._collect_output(outputs, batch))

    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """Log validation loss, learning rate, early-stopping state and metrics."""
        _, _, total_loss, metric_score = self._reduce_outputs(
            pl_module.validation_step_outputs,
            pl_module.transform,
            pl_module.metrics,
        )
        pl_module.log("val_loss", total_loss, sync_dist=True)
        pl_module.log("lr", pl_module.optimizer.param_groups[0]["lr"], sync_dist=True)
        # ``early_stopping_callback`` is None when the trainer was built without
        # an EarlyStopping callback; skip these two logs instead of crashing.
        early_stopping = trainer.early_stopping_callback
        if early_stopping is not None:
            pl_module.log("last_wait_count", early_stopping.wait_count)
            pl_module.log("last_best_score", early_stopping.best_score)
        pl_module.log_dict(metric_score, sync_dist=True)

    def on_test_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        outputs: Dict[Literal["raw_output", "loss"], Union[Dict[str, Tensor], float]],
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Store the decorated output of one test batch on ``pl_module``."""
        pl_module.test_step_outputs.append(self._collect_output(outputs, batch))

    def on_test_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """Log test loss and metrics and publish them as ``pl_module.test_result``."""
        y_preds, y_truths, total_loss, metric_score = self._reduce_outputs(
            pl_module.test_step_outputs,
            pl_module.transform,
            pl_module.metrics,
        )
        pl_module.log("test_loss", total_loss, sync_dist=True)
        pl_module.log_dict(metric_score, sync_dist=True)
        pl_module.test_result = {
            "y_pred": y_preds,
            "y_truth": y_truths,
            "metric_score": metric_score,
        }
class MonitorCallback(L.Callback):
    """Feed raw model outputs to a ``Monitor`` and log its summary each epoch.

    Args:
        monitor: Object providing ``collect(raw_output)``, ``_summary()``
            (returning ``{term: {stat: value}}``) and ``_reset()``.
    """

    def __init__(self, monitor: Monitor):
        super().__init__()
        self.monitor = monitor

    def _log_summary(self, pl_module: L.LightningModule) -> None:
        """Log every ``(term, stat)`` entry of the monitor summary, then reset it."""
        summary_dict = self.monitor._summary()
        for term, stats in summary_dict.items():
            for stat, value in stats.items():
                pl_module.log(f"test_{term}_{stat}", value, sync_dist=True)
        self.monitor._reset()

    def on_validation_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        outputs: Dict[Literal["raw_output", "loss"], Union[Dict[str, Tensor], float]],
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Pass the raw output of one validation batch to the monitor."""
        self.monitor.collect(outputs["raw_output"])

    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """Log and reset the monitor statistics gathered during validation."""
        # NOTE(review): validation statistics are logged under the same
        # "test_" prefix as the test statistics; this looks like a copy-paste
        # slip but is kept unchanged because consumers may rely on the key.
        self._log_summary(pl_module)

    def on_test_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        outputs: Dict[Literal["raw_output", "loss"], Union[Dict[str, Tensor], float]],
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Pass the raw output of one test batch to the monitor.

        The ``trainer`` parameter is required: Lightning invokes this hook
        positionally as ``(trainer, pl_module, outputs, batch, batch_idx, ...)``,
        so omitting it shifts every argument by one position.
        """
        self.monitor.collect(outputs["raw_output"])

    def on_test_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """Log and reset the monitor statistics gathered during testing."""
        self._log_summary(pl_module)
class EMACallback(L.Callback):
    """Maintain an exponential moving average (EMA) of ``pl_module.model`` weights.

    With ``use_ema`` enabled the average is updated after every training batch
    and swapped into the model for the duration of each validation/test epoch.
    With ``use_ema`` disabled an EMA object is still created in ``setup`` (decay
    ``1``) so that EMA state found in a checkpoint can be loaded and copied
    into the model.

    Args:
        use_ema: Whether to update and apply the moving average.
        ema_decay: Decay passed to ``ExponentialMovingAverage`` when enabled.
        ema_use_num_updates: ``use_num_updates`` flag passed to
            ``ExponentialMovingAverage`` when enabled.
    """

    def __init__(
        self,
        use_ema: bool,
        ema_decay: Optional[float] = None,
        ema_use_num_updates: Optional[bool] = None,
    ):
        super().__init__()
        self.use_ema = use_ema
        self.ema_decay = ema_decay
        self.ema_use_num_updates = ema_use_num_updates

    def setup(self, trainer: L.Trainer, pl_module: L.LightningModule, stage: str):
        """Create the EMA tracker over the model parameters on the module's device."""
        if self.use_ema:
            decay, use_num_updates = self.ema_decay, self.ema_use_num_updates
        else:
            # Placeholder tracker; it is never updated when EMA is disabled.
            decay, use_num_updates = 1, True
        self.ema = ExponentialMovingAverage(
            pl_module.model.to(pl_module.device).parameters(),
            decay,
            use_num_updates,
        )

    def on_train_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ):
        """Update the moving average after each training batch."""
        if self.use_ema:
            self.ema.update()

    def on_validation_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule):
        """Stash the live weights and evaluate with the averaged ones."""
        if self.use_ema:
            self.ema.store()
            self.ema.copy_to()

    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule):
        """Put the live training weights back after validation."""
        if self.use_ema:
            self.ema.restore()

    def on_test_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule):
        """Stash the live weights and test with the averaged ones."""
        if self.use_ema:
            self.ema.store()
            self.ema.copy_to()

    def on_test_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule):
        """Put the live weights back after testing."""
        if self.use_ema:
            self.ema.restore()

    def state_dict(self):
        """Return the EMA state for checkpointing (empty dict when EMA is off)."""
        if self.use_ema:
            return self.ema.state_dict()
        return dict()

    def load_state_dict(self, state_dict: dict):
        """Restore EMA state from a checkpoint.

        An empty ``state_dict`` is ignored. When EMA is disabled for this run
        but the checkpoint carries EMA state, the averaged weights are copied
        into the tracked model parameters so the model uses them.

        Args:
            state_dict: State previously produced by ``state_dict``.
        """
        if not state_dict:
            return
        # Load first so the copy below uses the checkpointed averages rather
        # than the shadow copy taken from the model in ``setup``.
        self.ema.load_state_dict(state_dict)
        if not self.use_ema:
            # The callback has no ``model`` attribute; ``copy_to()`` without
            # arguments writes into the parameters registered in ``setup``.
            self.ema.copy_to()
class LightningModel(L.LightningModule):
    """Lightning wrapper bundling a model with its losses, optimizer and callbacks.

    The constructor only stores its arguments as attributes of the same name
    and initialises the buffers (``validation_step_outputs``,
    ``test_step_outputs``, ``test_result``) that the callbacks fill in.
    ``loss_terms`` is iterated with ``.values()``, so it must be a mapping of
    callables taking ``(output, net_target)``.
    """

    def __init__(
        self,
        model: Module,
        loss_terms: Iterable[Callable],
        dump_dir: str,
        optimizer: Optimizer,
        scheduler: LRScheduler,
        monitor: Monitor,
        transform: Transform,
        metrics: Metrics,
        use_ema: bool,
        ema_decay: float,
        ema_use_num_updates: int,
        dump_interval: int
    ):
        super().__init__()
        # Network and objective.
        self.model = model
        self.loss_terms = loss_terms
        self.dump_dir = dump_dir
        # Optimisation state handed back verbatim in configure_optimizers().
        self.optimizer = optimizer
        self.scheduler = scheduler
        # Evaluation helpers read by the callbacks.
        self.monitor = monitor
        self.transform = transform
        self.metrics = metrics
        # EMA configuration forwarded to EMACallback.
        self.ema_decay = ema_decay
        self.ema_use_num_updates = ema_use_num_updates
        self.use_ema = use_ema
        # Buffers populated by CollectOutputCallback.
        self.validation_step_outputs = []
        self.test_step_outputs = []
        self.test_result = None
        self.dump_interval = dump_interval

    def _accumulate_loss(self, output, net_target):
        """Add up every loss term evaluated on ``(output, net_target)``."""
        total = 0
        for term in self.loss_terms.values():
            total += term(output, net_target)
        return total

    def training_step(self, batch, batch_idx):
        """Run the model on one training batch and return ``{"loss": total}``."""
        features, targets = batch
        with torch.enable_grad():
            prediction = self.model(features)
            total = self._accumulate_loss(prediction, targets)
        return {"loss": total}

    def _prediction_step(self, batch):
        """Shared validation/test step returning the loss and the raw model output.

        The forward pass runs with gradients enabled and ``net_input["Ra"]``
        marked as requiring grad; the loss terms are evaluated without
        gradient tracking.
        """
        features, targets = batch
        with torch.enable_grad():
            features["Ra"].requires_grad_(True)
            prediction = self.model(features)
        with torch.no_grad():
            total = self._accumulate_loss(prediction, targets)
        return {"loss": total, "raw_output": prediction}

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch via ``_prediction_step``."""
        return self._prediction_step(batch)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch via ``_prediction_step``."""
        return self._prediction_step(batch)

    def configure_optimizers(self):
        """Return the stored optimizer with its scheduler stepped every batch."""
        scheduler_config = {
            "scheduler": self.scheduler,
            "interval": "step",
            "frequency": 1,
        }
        return {"optimizer": self.optimizer, "lr_scheduler": scheduler_config}

    def configure_callbacks(self):
        """Build the checkpoint, output-collection, monitor and EMA callbacks.

        Always returns the best-model checkpoint (monitoring ``"_judge_score"``,
        minimised), the last-model checkpoint and the output collector, in that
        order. A periodic checkpoint, the monitor callback and the EMA callback
        follow when ``dump_interval > 0``, ``monitor`` is set and ``use_ema``
        is true, respectively.
        """

        def pth_checkpoint(**options):
            # Every checkpoint lives in dump_dir and uses the ".pth" suffix.
            checkpoint = ModelCheckpoint(dirpath=self.dump_dir, **options)
            checkpoint.FILE_EXTENSION = ".pth"
            return checkpoint

        callbacks = [
            pth_checkpoint(
                monitor="_judge_score",
                mode="min",
                save_top_k=1,
                filename="model_best",
            ),
            pth_checkpoint(filename="model_last"),
            CollectOutputCallback(),
        ]
        if self.dump_interval > 0:
            callbacks.append(
                pth_checkpoint(
                    filename="model_{epoch}",
                    every_n_epochs=self.dump_interval,
                    save_top_k=-1,
                )
            )
        if self.monitor is not None:
            callbacks.append(MonitorCallback(self.monitor))
        if self.use_ema:
            callbacks.append(
                EMACallback(self.use_ema, self.ema_decay, self.ema_use_num_updates)
            )
        return callbacks