# Source code for enerzyme.tasks.trainer

from functools import partial
from typing import Iterable, Optional, Callable, Dict, Any, Literal
from collections import defaultdict
import time, os, logging, contextlib
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.nn import Module
from torch.nn.utils import clip_grad_norm_
from torch_ema import ExponentialMovingAverage
# Best-effort silencing of TensorFlow's Python logger and C++ log output.
# This runs before the `transformers` import below (presumably because that
# import may pull in TensorFlow -- TODO confirm). Failures are ignored on
# purpose, but only ordinary exceptions: a bare `except:` would also swallow
# KeyboardInterrupt / SystemExit.
try:
    logging.getLogger('tensorflow').disabled = True
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
except Exception:
    pass
from transformers.optimization import get_scheduler
import numpy as np
from .splitter import Splitter
from .batch import _decorate_batch_input, _decorate_batch_output, _to_device, _decorate_pyg_batch_input, _pyg_to_device
from .monitor import Monitor
from .optimizer import get_optimizer, get_optimizer_config
from ..data.transform import Transform
from ..utils import logger
from .metrics import Metrics


# Accepted spellings of the trainer's `dtype` option and the torch dtype each
# one selects. Three names alias single precision, two alias double precision.
DTYPE_MAPPING = dict(
    float64=torch.float64,
    float32=torch.float32,
    float=torch.float32,
    double=torch.float64,
    single=torch.float32,
)


def _modify_lightning_state_dict(lightning_ckpt_path: str, patience: int) -> None:
    '''
    Raise the patience of the EarlyStopping callback(s) stored in a Lightning checkpoint.

    Every entry of the checkpoint's ``"callbacks"`` mapping whose key starts with
    ``"EarlyStopping"`` and whose stored ``patience`` is smaller than `patience` is
    updated. The checkpoint file is rewritten in place only if at least one entry
    was changed; the stored patience is never lowered.

    .. danger::

        This function should be only used in the world rank 0 during distributed training.
        Otherwise, the checkpoint will be corrupted.

    Params:
    ----------
    lightning_ckpt_path: str
        The path to the Lightning checkpoint.

    patience: int
        The patience to set.

    Returns:
    ----------
    None
    '''
    loaded_info = torch.load(lightning_ckpt_path, map_location="cpu")
    modified = False
    for k, v in loaded_info.get("callbacks", {}).items():
        if not k.startswith("EarlyStopping"):
            continue
        # Keep scanning after an entry that needs no change: returning early here
        # would discard an increase already applied to a previous entry.
        if 'patience' in v and v['patience'] < patience:
            v['patience'] = patience
            modified = True
            logger.info(f"increase patience of EarlyStopping to {patience}")
    if modified:
        # Only touch the file when something actually changed.
        torch.save(loaded_info, lightning_ckpt_path)


def _convert_lightning_model_state_dict(lightning_model_state_dict: Dict) -> Dict:
    '''
    Strip the leading ``"model."`` from the keys of a state dict.

    Keys that start with ``"model."`` lose that prefix (only the leading
    occurrence); every other key is copied unchanged. Values are not copied.
    '''
    prefix = "model."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in lightning_model_state_dict.items()
    }


def _convert_lightning_state_dict(lightning_loaded_info: Dict) -> Dict:
    '''
    Translate a loaded Lightning checkpoint into the key layout used by `_load_state_dict`.

    Params:
    ----------
    lightning_loaded_info: Dict
        The dictionary loaded from a Lightning checkpoint.

    Returns:
    ----------
    A dictionary holding whichever of these keys could be derived from the input:

    model_state_dict
        ``state_dict`` with the leading ``"model."`` removed from its keys.
    epoch
        Copied from ``epoch``.
    scheduler_state_dict / optimizer_state_dict
        The first entry of ``lr_schedulers`` / ``optimizer_states``. Skipped when the
        list is missing or empty.
    ema_state_dict
        The state stored under the ``"EMACallback"`` callback key.
    best_score / best_epoch
        Taken from callback entries whose key starts with ``"EarlyStopping"``;
        ``best_epoch`` is ``epoch - wait_count`` and needs ``epoch`` to be present.
    '''
    loaded_info = dict()
    if "state_dict" in lightning_loaded_info:
        loaded_info["model_state_dict"] = _convert_lightning_model_state_dict(lightning_loaded_info["state_dict"])
    if "epoch" in lightning_loaded_info:
        loaded_info["epoch"] = lightning_loaded_info["epoch"]
    # Both entries are lists and may be empty, so guard the `[0]` access
    # instead of raising IndexError.
    lr_schedulers = lightning_loaded_info.get("lr_schedulers")
    if lr_schedulers:
        loaded_info["scheduler_state_dict"] = lr_schedulers[0]
    optimizer_states = lightning_loaded_info.get("optimizer_states")
    if optimizer_states:
        loaded_info["optimizer_state_dict"] = optimizer_states[0]
    for k, v in lightning_loaded_info.get("callbacks", {}).items():
        if k == "EMACallback":
            loaded_info["ema_state_dict"] = v
        elif k.startswith("EarlyStopping"):
            if 'best_score' in v:
                loaded_info["best_score"] = v['best_score']
            if 'wait_count' in v and "epoch" in loaded_info:
                loaded_info["best_epoch"] = loaded_info["epoch"] - v['wait_count']
    return loaded_info


def _load_state_dict(model: Module, device: Optional[torch.device]=None, pretrain_path: Optional[str]=None, ema: Optional[ExponentialMovingAverage]=None, inference: bool=False, optimizer: Optional[Optimizer]=None, scheduler: Optional[LRScheduler]=None, strict: bool=True) -> Dict:
    '''
    Restore a model, and optionally its EMA / optimizer / scheduler, from a checkpoint file.

    Params:
    ----------
    model: Module
        The model whose parameters are loaded in place.

    device: Optional[torch.device]
        Passed to ``torch.load`` as ``map_location``.

    pretrain_path: Optional[str]
        The checkpoint path. If None, nothing is loaded and an empty dict is returned.
        Checkpoints containing the key ``"pytorch-lightning_version"`` are first
        converted with `_convert_lightning_state_dict`.

    ema: Optional[ExponentialMovingAverage]
        If given and the checkpoint has ``"ema_state_dict"``, both the model and this
        EMA object are restored.

    inference: bool
        If True and no `ema` object is restored, an ``"ema_state_dict"`` in the
        checkpoint is copied into the model parameters. Optimizer, scheduler and
        epoch bookkeeping are skipped when True.

    optimizer, scheduler:
        Restored from the checkpoint when provided and `inference` is False.

    strict: bool
        Forwarded to ``model.load_state_dict`` only when the checkpoint's EMA state
        is not used; the two EMA branches load with the default strictness.

    Returns:
    ----------
    A dictionary that may contain ``model_state_dict`` (plain model load only) and,
    when `inference` is False, ``epoch``, ``best_epoch`` and ``best_score``.
    '''
    resume_info = dict()
    if pretrain_path is None:
        return resume_info

    checkpoint = torch.load(pretrain_path, map_location=device)
    if 'pytorch-lightning_version' in checkpoint:
        checkpoint = _convert_lightning_state_dict(checkpoint)

    has_ema_state = "ema_state_dict" in checkpoint
    if ema is not None and has_ema_state:
        # Training-style restore: raw weights into the model, averages into `ema`.
        model.load_state_dict(checkpoint["model_state_dict"])
        ema.load_state_dict(checkpoint["ema_state_dict"])
        logger.info(f"loading ema state dict from {pretrain_path}...")
    elif inference and has_ema_state:
        # Inference-style restore: a throwaway EMA object is used to copy the
        # stored EMA state into the model parameters.
        model.load_state_dict(checkpoint["model_state_dict"])
        scratch_ema = ExponentialMovingAverage(model.parameters(), decay=1, use_num_updates=True)
        scratch_ema.load_state_dict(checkpoint["ema_state_dict"])
        scratch_ema.copy_to(model.parameters())
        logger.info(f"loading averaged model state dict from {pretrain_path}...")
    else:
        model.load_state_dict(checkpoint["model_state_dict"], strict=strict)
        resume_info["model_state_dict"] = checkpoint["model_state_dict"]
        logger.info(f"loading model state dict from {pretrain_path}...")

    if inference:
        return resume_info

    if optimizer is not None and "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        logger.info(f"loading optimizer state dict from {pretrain_path}...")
    if scheduler is not None and "scheduler_state_dict" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        logger.info(f"loading scheduler state dict from {pretrain_path}...")
    for key in ("epoch", "best_epoch", "best_score"):
        if key in checkpoint:
            resume_info[key] = checkpoint[key]
    # NOTE: restoring the torch / torch.cuda / numpy RNG states from the checkpoint
    # ("torch_rng_state", "torch_cuda_rng_state_all", "np_rng_state") is disabled;
    # the corresponding code was commented out.
    return resume_info


class Trainer:
    '''
    Train, validate and test a model, either with a hand-written PyTorch loop or,
    when ``lightning=True`` is configured, with a Lightning ``Trainer``.
    '''

    def __init__(self, out_dir: str=None, metric_config: Optional[Dict]=None, **params) -> None:
        '''
        The trainer class for training and evaluating the model.

        Params:
        ----------
        out_dir: str
            The directory to save the model.

        metric_config: dict
            The configuration for the :doc:`Metrics <enerzyme.tasks.metrics.Metrics>` class.
            Defaults to an empty configuration.

        **params: dict
            The configuration for the trainer. ``params["Splitter"]`` is required;
            every other option read here has a default.
        '''
        # Avoid a shared mutable default argument.
        if metric_config is None:
            metric_config = dict()
        self.batch_size = params.get('batch_size', 8)
        self.pyg = params.get("pyg", False)
        self.lightning = params.get("lightning", False)
        self.patience = params.get('patience', 50)
        self.max_norm = params.get('max_norm', -1)
        self.config = params
        self.out_dir = out_dir
        self.metric_config = metric_config
        self.metrics = Metrics(metric_config)
        self.optimizer_name, self.optimizer_hyper_params = get_optimizer_config(**params)
        self.splitter = Splitter(**params["Splitter"])
        if "Monitor" in params:
            self.monitor = Monitor(**params["Monitor"])
        else:
            self.monitor = None
        self.seed = params.get('seed', 114514)
        self.inference_batch_size = params.get('inference_batch_size', self.batch_size)
        self.max_epochs = params.get('max_epochs', 1000)
        self.warmup_ratio = params.get('warmup_ratio', 0.01)
        self.cuda = params.get('cuda', False)
        self.schedule = params.get('schedule', "linear")
        self.data_in_memory = params.get("data_in_memory", True)
        self.use_ema = params.get("use_ema", True)
        self.ema_decay = params.get("ema_decay", 0.999)
        self.ema_use_num_updates = params.get("ema_use_num_updates", True)
        self.dtype = DTYPE_MAPPING[params.get('dtype', "float32")]
        self.committee_size = params.get("committee_size", 1)
        self.dump_interval = params.get("dump_interval", -1)
        self.active_learning_params = params.get("active_learning_params", None)
        if self.active_learning_params is not None and self.active_learning_params.get("active", False):
            self.active_learning = True
        else:
            self.active_learning = False
        self.resume = params.get("resume", 1)
        # When not None, these two override the same-named arguments of `fit_predict`.
        self.refresh_best_score = params.get("refresh_best_score", None)
        self.refresh_patience = params.get("refresh_patience", None)
        self.freeze_pretrain_weights = params.get("freeze_pretrain_weights", False)
        non_target_features = params.get("non_target_features", [])
        self.num_workers = params.get("num_workers", 0)
        if self.num_workers <= 0:
            # No explicit worker count: derive one from the SLURM task count or the CPU count.
            if "SLURM_NTASKS" in os.environ:
                self.num_workers = max(1, int(os.environ["SLURM_NTASKS"]) // 2 - 1)
                logger.info(f"using {self.num_workers} workers for dataloader given {os.environ['SLURM_NTASKS']} tasks")
            else:
                # os.cpu_count() may return None when the count cannot be determined.
                cpu_count = os.cpu_count()
                self.num_workers = max(1, (cpu_count or 1) // 2 - 1)
                logger.info(f"using {self.num_workers} workers for dataloader given {cpu_count} cpus")
        else:
            logger.info(f"using {self.num_workers} workers for dataloader")
        if isinstance(non_target_features, list):
            self.non_target_features = non_target_features
        elif isinstance(non_target_features, str):
            self.non_target_features = [non_target_features]
        else:
            raise ValueError(f"non_target_features must be a list or a string, but got {type(non_target_features)}")
        if self.lightning:
            # Device placement is left to Lightning.
            self.device = None
            import lightning as L
            from lightning.pytorch.callbacks import EarlyStopping
            early_stopping_callback = EarlyStopping(
                monitor="_judge_score",
                mode="min",
                patience=self.patience,
                min_delta=0
            )
            self.lightning_trainer = L.Trainer(
                default_root_dir=out_dir,
                accelerator="auto",
                callbacks=[early_stopping_callback],
                devices="auto",
                num_nodes=int(os.environ.get("SLURM_NNODES", 1)),
                strategy="auto",
                gradient_clip_val=self.max_norm if self.max_norm > 0 else None,
                inference_mode=False,
                max_epochs=self.max_epochs
            )
        else:
            if torch.cuda.is_available():
                logger.info("GPU found!")
                self.device = torch.device("cuda:0" if self.cuda else "cpu")
            else:
                logger.info("GPU not found, turn to CPU!")
                self.device = torch.device("cpu")

    def decorate_batch_input(self, batch):
        '''
        Collate function handed to the DataLoaders; dispatches to the PyG or the
        default batch decorator depending on ``self.pyg``.
        '''
        if self.pyg:
            return _decorate_pyg_batch_input(batch, self.dtype, self.device)
        else:
            return _decorate_batch_input(batch, self.dtype, self.device)

    def to_device(self, batch):
        '''
        Move a collated batch to ``self.device`` using the PyG or the default helper.
        The callers in this class unpack the result as ``(net_input, net_target)``.
        '''
        if self.pyg:
            return _pyg_to_device(batch, self.device)
        else:
            return _to_device(batch, self.device)

    def decorate_batch_output(self, output, features, targets):
        '''
        Post-process a batch of model outputs via `_decorate_batch_output`, passing
        ``self.non_target_features`` along. `predict` unpacks the result as
        ``(y_pred, y_truth)`` dictionaries.
        '''
        return _decorate_batch_output(output, features, targets, self.non_target_features)

    def _set_seed(self, seed):
        """function used to set a random seed

        Arguments:
            seed {int} -- seed number, will set to torch and numpy
        """
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)

    def load_state_dict(
        self, model: Module,
        optimizer: Optional[Optimizer]=None,
        scheduler: Optional[LRScheduler]=None,
        pretrain_path: Optional[str]=None,
        ema: Optional[ExponentialMovingAverage]=None,
        inference: bool=False,
        strict: bool=True
    ) -> Dict:
        '''
        Load a checkpoint into `model` (and optionally `optimizer`, `scheduler`, `ema`)
        by delegating to `_load_state_dict` with ``self.device`` as the map location.

        Returns:
        ----------
        The dictionary returned by `_load_state_dict`.
        '''
        return _load_state_dict(model, self.device, pretrain_path, ema, inference, optimizer, scheduler, strict)

    def save_state_dict(self, model: Module, optimizer: Optimizer, scheduler: LRScheduler, dump_dir: str, ema: Optional[ExponentialMovingAverage]=None, suffix="last", model_rank=None, epoch: Optional[int]=None, best_score: Optional[float]=None, best_epoch: Optional[int]=None):
        '''
        Save a checkpoint to ``{dump_dir}/model{model_rank}_{suffix}.pth``.

        The checkpoint always holds the model, optimizer and scheduler state dicts;
        the EMA state, `epoch`, `best_score` and `best_epoch` are added only when
        they are not None. `dump_dir` is created if missing. A `model_rank` of None
        is written as an empty string in the file name.
        '''
        if model_rank is None:
            model_rank = ''
        os.makedirs(dump_dir, exist_ok=True)
        info = {"model_state_dict": model.state_dict()}
        if ema is not None:
            info["ema_state_dict"] = ema.state_dict()
        info["optimizer_state_dict"] = optimizer.state_dict()
        info["scheduler_state_dict"] = scheduler.state_dict()
        if epoch is not None:
            info["epoch"] = epoch
        if best_score is not None:
            info["best_score"] = best_score
        if best_epoch is not None:
            info["best_epoch"] = best_epoch
        # info["torch_rng_state"] = torch.random.get_rng_state()
        # if self.cuda:
        #     info["torch_cuda_rng_state_all"] = torch.cuda.random.get_rng_state_all()
        # info["np_rng_state"] = np.random.get_state()
        torch.save(info, os.path.join(dump_dir, f'model{model_rank}_{suffix}.pth'))

    def fit_predict(self,
        model: Module,
        pretrain_path: Optional[str],
        train_dataset: Dataset,
        valid_dataset: Optional[Dataset],
        loss_terms: Iterable[Callable],
        dump_dir: str,
        transform: Transform,
        test_dataset: Optional[Dataset]=None,
        model_rank: Optional[int]=None,
        max_epoch_per_iter: int=-1,
        meta_state_dict: Optional[Dict]=None,
        refresh_patience: bool=False,
        refresh_best_score: bool=False
    ) -> Dict[Literal["y_pred", "y_truth", "metric_score"], Any]:
        '''
        Train the model on the training set, validate it on the validation set, and test the model on the test set.

        Params:
        ----------
        model: Module
            The model to train

        pretrain_path: Optional[str]
            The path to the pretrained model or the checkpoint of the model

        train_dataset: Dataset
            The training dataset.

        valid_dataset: Optional[Dataset]
            The validation dataset. If not provided, the model will not be validated.

        loss_terms: Iterable[Callable]
            The loss functions with multiple terms to use. The training loop iterates
            ``loss_terms.values()``, so a mapping of callables is expected.

        dump_dir: str
            The directory to save the model

        transform: Transform
            The data transform in preprocessing. The inverse transform will be applied to the prediction results when calculating the metrics during validation and testing.

        test_dataset: Optional[Dataset]
            The test dataset. If not provided, the model will not be tested.

        model_rank: Optional[int]
            Only used in deep ensemble training. The rank of the model.

        max_epoch_per_iter: int
            Only used in active learning. The maximum number of epochs per active learning iteration.

        meta_state_dict: Optional[Dict]
            Only used in active learning checkpointing. The meta state dictionary. It is
            updated in place (``epoch_in_iter``, ``model_rank``). If None, a fresh
            dictionary is used.

        refresh_patience: bool
            Whether to refresh the patience when loading the checkpoint. Overridden by
            the trainer-level ``refresh_patience`` option when that is not None.

        refresh_best_score: bool
            Whether to refresh the best score when loading the checkpoint. Overridden by
            the trainer-level ``refresh_best_score`` option when that is not None.

        Returns:
        ----------
        The prediction results on the test set. A dictionary with the following keys:

        y_pred
            The predicted results.
        y_truth
            The true results.
        metric_score
            The metrics score based on the predicted and true results.

        All three values are None when no test dataset is given. In the Lightning
        path no dictionary is built and None is returned
        (NOTE(review): reconstructed from source without indentation -- confirm).
        '''
        # A mutable default would be shared (and mutated) across calls.
        if meta_state_dict is None:
            meta_state_dict = dict()
        if self.refresh_best_score is not None:
            refresh_best_score = self.refresh_best_score
        if self.refresh_patience is not None:
            refresh_patience = self.refresh_patience
        self._set_seed(self.seed + (model_rank if model_rank is not None else 0))
        train_dataloader = DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.decorate_batch_input,
            num_workers=self.num_workers,
            drop_last=True
        )
        if self.lightning:
            num_training_steps = len(train_dataloader) * self.max_epochs // self.lightning_trainer.world_size
        else:
            num_training_steps = len(train_dataloader) * self.max_epochs
        num_warmup_steps = int(num_training_steps * self.warmup_ratio)
        optimizer = get_optimizer(self.optimizer_name, model, self.optimizer_hyper_params)
        scheduler = get_scheduler(self.schedule, optimizer, num_warmup_steps, num_training_steps)

        if self.lightning:
            logger.info("Using Lightning Trainer")
            from .lightning_utils import LightningModel
            lightning_model = LightningModel(
                model.type(self.dtype),
                loss_terms,
                dump_dir,
                optimizer,
                scheduler,
                monitor=self.monitor,
                transform=transform,
                metrics=self.metrics,
                use_ema=self.use_ema,
                ema_decay=self.ema_decay,
                ema_use_num_updates=self.ema_use_num_updates,
                dump_interval=self.dump_interval
            )
            train_dataloader = DataLoader(
                dataset=train_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                collate_fn=self.decorate_batch_input,
                drop_last=True,
                pin_memory=True,
                num_workers=max(1, self.num_workers)
            )
            if valid_dataset is not None:
                valid_dataloader = DataLoader(
                    dataset=valid_dataset,
                    batch_size=self.inference_batch_size,
                    shuffle=False,
                    collate_fn=self.decorate_batch_input,
                    pin_memory=True,
                    num_workers=max(1, self.num_workers)
                )
            else:
                valid_dataloader = None
            if self.resume > 1 and pretrain_path is not None:
                # if the world rank is 0, modify the state dict
                if self.lightning_trainer.is_global_zero:
                    _modify_lightning_state_dict(pretrain_path, self.patience)
                logger.info(f"Resuming from {pretrain_path}...")
            self.lightning_trainer.fit(
                lightning_model,
                train_dataloader,
                valid_dataloader,
                ckpt_path=pretrain_path if self.resume > 1 else None
            )
            if test_dataset is not None:
                test_dataloader = DataLoader(
                    dataset=test_dataset,
                    batch_size=self.inference_batch_size,
                    shuffle=False,
                    collate_fn=self.decorate_batch_input,
                    pin_memory=True,
                    num_workers=max(1, self.num_workers)
                )
                # Test the Lightning wrapper that was fitted above, not the raw module.
                self.lightning_trainer.test(lightning_model, test_dataloader)
        else:
            model = model.to(self.device).type(self.dtype)
            if self.use_ema:
                ema = ExponentialMovingAverage(
                    model.parameters(), self.ema_decay, self.ema_use_num_updates
                )
            else:
                ema = None
            # resume > 1: full restore (model, optimizer, scheduler, EMA, bookkeeping);
            # resume == 1: model (and EMA) only; otherwise: non-strict, inference-style load.
            if self.resume > 1:
                other_info = self.load_state_dict(model, optimizer, scheduler, pretrain_path, ema)
            elif self.resume == 1:
                other_info = self.load_state_dict(model, pretrain_path=pretrain_path, ema=ema)
            else:
                other_info = self.load_state_dict(model, pretrain_path=pretrain_path, inference=True, strict=False)
            if self.freeze_pretrain_weights:
                for name, param in model.named_parameters():
                    if name in other_info.get("model_state_dict", dict()):
                        param.requires_grad = False
                        logger.info(f"Freezing pretrained weights for {name}")

            # NOTE(review): the nesting of the next two branches was reconstructed from
            # source without indentation; both set `wait`, `best_score`, `start_epoch`.
            if self.resume > 1 and "best_epoch" in other_info and "epoch" in other_info:
                wait = other_info["epoch"] - other_info["best_epoch"]
                if refresh_best_score:
                    best_score = float("inf")
                else:
                    best_score = other_info.get("best_score", float("inf"))
                start_epoch = other_info["epoch"] + 1
                if (wait >= self.patience and refresh_patience) or refresh_best_score:
                    wait = 0
            else:
                wait = 0
                if self.resume > 1:
                    start_epoch = other_info.get("epoch", -1) + 1
                else:
                    start_epoch = 0
                if valid_dataset is not None:
                    if self.resume > 1 and not refresh_best_score:
                        best_score = other_info.get("best_score", float("inf"))
                    else:
                        best_score = float("inf")
                else:
                    best_score = None
            if self.resume > 1:
                max_epochs = self.max_epochs
            else:
                max_epochs = start_epoch + self.max_epochs
            if valid_dataset is not None:
                if self.resume > 1:
                    best_epoch = other_info.get("best_epoch", start_epoch)
                else:
                    best_epoch = None
            else:
                best_epoch = None

            # Pre-assign `epoch` so it is defined even if the loop body never runs.
            epoch = start_epoch
            epoch_in_iter = meta_state_dict.get("epoch_in_iter", 0)
            if start_epoch > 0:
                if epoch_in_iter > 0:
                    logger.info(f"Resuming from epoch {start_epoch + 1}, epoch {epoch_in_iter + 1} in active learning iteration")
                else:
                    logger.info(f"Resuming from epoch {start_epoch + 1}...")
            for epoch in range(start_epoch, max_epochs):
                if max_epoch_per_iter > 0 and epoch_in_iter >= max_epoch_per_iter:
                    break
                model = model.train()
                start_time = time.time()
                batch_bar = tqdm(
                    total=len(train_dataloader),
                    dynamic_ncols=True,
                    leave=False,
                    position=0,
                    desc='Train',
                    ncols=5
                )
                trn_loss = []
                for i, batch in enumerate(train_dataloader):
                    net_input, net_target = self.to_device(batch)
                    loss = 0
                    with torch.set_grad_enabled(True):
                        output = model(net_input)
                        for loss_term in loss_terms.values():
                            loss += loss_term(output, net_target)
                    trn_loss.append(float(loss.data))
                    batch_bar.set_postfix(
                        Epoch="Epoch {}/{}".format(epoch+1, max_epochs),
                        loss="{:.04f}".format(float(sum(trn_loss) / (i + 1))),
                        lr="{:.04f}".format(float(optimizer.param_groups[0]['lr']))
                    )
                    # see https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-parameter-grad-none-instead-of-model-zero-grad-or-optimizer-zero-grad
                    optimizer.zero_grad(set_to_none=True)
                    loss.backward()
                    if self.max_norm > 0:
                        clip_grad_norm_(model.parameters(), self.max_norm)
                    optimizer.step()
                    if self.use_ema:
                        ema.update()
                    scheduler.step()
                    batch_bar.update()

                batch_bar.close()
                total_trn_loss = np.mean(trn_loss)
                message = f'Epoch [{epoch + 1}/{max_epochs}]' + (f' ({epoch_in_iter + 1}/{max_epoch_per_iter})' if max_epoch_per_iter > 0 else '') + f', train_loss: {total_trn_loss:.4f}, lr: {optimizer.param_groups[0]["lr"]:.6f}'
                # Validate with the EMA-averaged parameters when EMA is enabled.
                if self.use_ema:
                    cm = ema.average_parameters()
                else:
                    cm = contextlib.nullcontext()
                if valid_dataset is not None:
                    with cm:
                        predict_result = self.predict(
                            model=model,
                            dataset=valid_dataset,
                            loss_terms=loss_terms,
                            dump_dir=dump_dir,
                            transform=transform,
                            epoch=epoch,
                            load_model=False,
                        )
                    # NOTE(review): the extent of the `with cm:` block was lost with the
                    # indentation; here the "best" checkpoint is saved after leaving the
                    # EMA context, like the "last" checkpoint below -- confirm.
                    val_loss = predict_result["val_loss"]
                    metric_score = predict_result["metric_score"]
                    total_val_loss = np.mean(val_loss)
                    _score = metric_score["_judge_score"]
                    _metric = str(self.metrics)
                    save_handle = partial(self.save_state_dict, model=model, optimizer=optimizer, scheduler=scheduler, dump_dir=dump_dir, ema=ema, suffix="best", model_rank=model_rank)
                    is_early_stop, best_score, wait, saved = self._early_stop_choice(wait, best_score, metric_score, save_handle, self.patience, epoch)
                    if saved:
                        best_epoch = epoch
                    message += f', val_loss: {total_val_loss:.4f}, ' + \
                        ", ".join([f'val_{k}: {v:.4f}' for k, v in metric_score.items() if k != "_judge_score"]) + \
                        f', val_judge_score ({_metric}): {_score:.4f}' + \
                        (f', Patience [{wait}/{self.patience}], min_val_judge_score: {best_score:.4f}' if wait else '')
                else:
                    is_early_stop = False
                epoch_in_iter += 1
                meta_state_dict.update({"epoch_in_iter": epoch_in_iter})
                end_time = time.time()
                message += f', {(end_time - start_time):.1f}s'
                logger.info(message)
                self.save_state_dict(model, optimizer, scheduler, dump_dir, ema, "last", model_rank, epoch=epoch, best_score=best_score, best_epoch=best_epoch)
                if self.dump_interval > 0 and (epoch + 1) % self.dump_interval == 0:
                    self.save_state_dict(model, optimizer, scheduler, dump_dir, ema, f"epoch={epoch}", model_rank, epoch=epoch, best_score=best_score, best_epoch=best_epoch)
                if is_early_stop:
                    break

            meta_state_dict.update({"model_rank": model_rank + 1 if model_rank is not None else 0, "epoch_in_iter": 0})
            if test_dataset is not None:
                if self.use_ema:
                    cm = ema.average_parameters()
                else:
                    cm = contextlib.nullcontext()
                with cm:
                    # load_model=True makes `predict` reload the "best" checkpoint from dump_dir.
                    predict_result = self.predict(
                        model=model,
                        dataset=test_dataset,
                        loss_terms=loss_terms,
                        dump_dir=dump_dir,
                        transform=transform,
                        epoch=epoch,
                        load_model=True,
                        model_rank=model_rank
                    )
                y_pred = predict_result["y_pred"]
                y_truth = predict_result["y_truth"]
                metric_score = predict_result["metric_score"]
            else:
                y_pred = None
                y_truth = None
                metric_score = None
            return {"y_pred": y_pred, "y_truth": y_truth, "metric_score": metric_score}

    def _early_stop_choice(self, wait, min_loss, metric_score, save_handle, patience, epoch):
        '''
        Delegate the early-stopping decision to ``self.metrics._early_stop_choice``.
        `fit_predict` unpacks the result as ``(is_early_stop, best_score, wait, saved)``.
        '''
        return self.metrics._early_stop_choice(wait, min_loss, metric_score, save_handle, patience, epoch)

    def predict(self,
        model: Module,
        dataset: Dataset,
        loss_terms: Iterable[Callable],
        dump_dir: str,
        transform: Transform,
        epoch: int=1,
        load_model: bool=False,
        model_rank: Optional[str]=None,
        test_mode: bool=False
    ) -> Dict[Literal["y_pred", "y_truth", "val_loss", "metric_score"], Any]:
        '''
        Run the model over a dataset and compute the metrics.

        Params:
        ----------
        model: Module
            The model to evaluate. It is moved to ``self.device`` / ``self.dtype`` and put in eval mode.

        dataset: Dataset
            The dataset to predict on.

        loss_terms: Iterable[Callable]
            The loss terms; ``loss_terms.values()`` are evaluated only when `load_model` is False.

        dump_dir: str
            The directory searched for the "best" checkpoint when `load_model` is True.

        transform: Transform
            If not None, its ``inverse_transform`` is applied to the collected predictions and truths.

        epoch: int
            Only used for the progress-bar text.

        load_model: bool
            If True, the "best" checkpoint is loaded before predicting, no loss is
            computed and ``_judge_score`` is removed from the metric scores.

        model_rank: Optional[str]
            Passed to ``get_pretrain_path`` to locate the checkpoint.

        test_mode: bool
            Stored on ``model.test_mode`` during prediction; reset to False afterwards.

        Returns:
        ----------
        A dictionary with the following keys:

        y_pred / y_truth
            The predicted / true results, keyed by output name.
        val_loss
            The per-batch loss values (empty when `load_model` is True).
        metric_score
            The metrics score based on the predicted and true results.
        '''
        # NOTE(review): this reseeds the global torch / numpy RNGs on every call,
        # including the per-epoch validation calls made by `fit_predict` -- confirm
        # that this is intended.
        self._set_seed(self.seed)
        model = model.to(self.device).type(self.dtype)
        if load_model:
            from ..models import get_pretrain_path
            pretrain_path = get_pretrain_path(dump_dir, "best", model_rank)
            self.load_state_dict(model, pretrain_path=pretrain_path, inference=True)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            collate_fn=self.decorate_batch_input,
            num_workers=self.num_workers,
        )
        model = model.eval()
        model.test_mode = test_mode
        batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc='val', ncols=5)
        val_loss = []
        y_preds = defaultdict(list)
        y_truths = defaultdict(list)
        for i, batch in enumerate(dataloader):
            net_input, net_target = self.to_device(batch)
            # The forward pass runs with autograd enabled (outside `no_grad`),
            # presumably because some outputs are obtained by differentiation -- TODO confirm.
            output = model(net_input)
            if self.monitor is not None:
                self.monitor.collect(output)
            loss = 0
            with torch.no_grad():
                if not load_model:
                    for loss_term in loss_terms.values():
                        loss += loss_term(output, net_target)
                    val_loss.append(float(loss.data))
            y_pred, y_truth = self.decorate_batch_output(output, net_input, net_target)
            for k, v in y_pred.items():
                y_preds[k].extend(v)
            for k, v in y_truth.items():
                y_truths[k].extend(v)
            if not load_model:
                batch_bar.set_postfix(
                    Epoch="Epoch {}/{}".format(epoch+1, self.max_epochs),
                    loss="{:.04f}".format(float(np.sum(val_loss) / (i + 1)))
                )
            batch_bar.update()
        if self.monitor is not None:
            self.monitor.summary()
        if transform is not None:
            transform.inverse_transform(y_preds)
            transform.inverse_transform(y_truths)
        metric_score = self.metrics.cal_metric(y_truths, y_preds)
        if load_model and "_judge_score" in metric_score:
            metric_score.pop("_judge_score")
        batch_bar.close()
        model.test_mode = False
        return {"y_pred": y_preds, "y_truth": y_truths, "val_loss": val_loss, "metric_score": metric_score}