Source code for enerzyme.tasks.monitor
from typing import Dict, List
import numpy as np
import torch
from torch import Tensor
from torch_scatter import segment_sum_coo
from ..utils import logger
[docs]
class Monitor:
    """Accumulate per-batch model outputs and log summary statistics.

    Each keyword argument passed to the constructor names an output term to
    track and lists the statistics to report for it, e.g.
    ``Monitor(E=["mean", "std"])``. Supported statistic names are
    ``"mean"``, ``"std"``, ``"max"`` and ``"min"``; any other name is
    silently ignored.
    """

    # Statistic name -> NumPy reducer applied to the collected values.
    _STAT_FUNCS = {
        "mean": np.mean,
        "std": np.std,
        "max": np.max,
        "min": np.min,
    }

    def __init__(self, **terms: List[str]) -> None:
        """Create a monitor.

        Args:
            **terms: Mapping from output key to the list of statistic names
                to compute for that key.
        """
        self.terms = terms
        self._reset()

    def _reset(self) -> None:
        """Discard all collected values, keeping the configured terms."""
        self.collection = {k: [] for k in self.terms}

    def collect(self, output: Dict[str, Tensor]) -> None:
        """Append the values of every tracked term found in ``output``.

        A term ``k`` is read from ``output[k]`` when present. Otherwise, if
        ``output[k + "_a"]`` exists, it is summed per segment with
        ``segment_sum_coo`` using ``output["batch_seg"]`` as the segment
        index. Terms found under neither key are skipped for this call.

        Args:
            output: Mapping from output names to tensors.
        """
        with torch.no_grad():
            for k in self.terms:
                if k in output:
                    values = output[k].detach().cpu().numpy()
                elif k + "_a" in output:
                    # NOTE(review): torch_scatter's *_coo ops expect a sorted
                    # index; presumably batch_seg is sorted — confirm upstream.
                    values = segment_sum_coo(
                        output[k + "_a"].detach(), output["batch_seg"]
                    ).cpu().numpy()
                else:
                    continue
                # A 0-d (scalar) tensor yields a 0-d array, which list.extend
                # cannot iterate; atleast_1d makes it a one-element array and
                # leaves arrays of rank >= 1 unchanged.
                self.collection[k].extend(np.atleast_1d(values))

    def _summary(self) -> Dict[str, Dict[str, float]]:
        """Compute the requested statistics over the collected values.

        Returns:
            ``{term: {stat: value}}`` for each configured term. Unknown
            statistic names are omitted. A term with no collected values
            yields ``nan`` for each supported statistic instead of raising.
        """
        summary_dict = {}
        for term, stats in self.terms.items():
            values = self.collection[term]
            summary_dict[term] = {}
            for stat in stats:
                func = self._STAT_FUNCS.get(stat)
                if func is None:
                    continue
                # np.max/np.min raise ValueError on an empty sequence (and
                # np.mean/np.std warn), so report nan when nothing was
                # collected for this term.
                summary_dict[term][stat] = func(values) if values else float("nan")
        return summary_dict

    def summary(self) -> None:
        """Log the collected statistics via ``logger.info``, then reset."""
        message = []
        for term, stats in self._summary().items():
            message.append(f"-------- {term} ---------")
            message.extend(f"{stat}: {value}" for stat, value in stats.items())
        logger.info("\n" + "\n".join(message) + "\n-------------------------")
        self._reset()