Source code for enerzyme.models.physnet.loss
from typing import Dict
from torch import Tensor
[docs]
class NHLoss:
    """Callable loss term that scales a precomputed ``"nh_loss"`` model output.

    The term itself is not computed here: it is read from the model's output
    dictionary under the key ``"nh_loss"`` and multiplied by ``weight``.
    NOTE(review): "nh" presumably refers to PhysNet's non-hierarchicality
    penalty -- confirm against the model code.
    """

    def __init__(self, weight: float = 1.0) -> None:
        """Store the multiplicative weight applied to the penalty.

        Args:
            weight: Factor the ``"nh_loss"`` value is multiplied by. The
                default of 1.0 leaves the penalty unscaled.
                NOTE(review): the default is a neutral placeholder -- confirm
                the value the training configuration is expected to supply.
        """
        # Without this assignment ``__call__`` fails with AttributeError,
        # because nothing else in the class defines ``weight``.
        self.weight = weight

    def __call__(self, output: Dict[str, Tensor], target: Dict[str, Tensor]) -> Tensor:
        """Return the weighted penalty taken from ``output``.

        Args:
            output: Model outputs; only the optional ``"nh_loss"`` entry is read.
            target: Unused here. Presumably accepted so every registered loss
                shares the same call signature -- verify against the caller.

        Returns:
            ``output["nh_loss"] * weight``. When the key is absent the fallback
            ``0`` is used, so the result is a plain number rather than a Tensor.
        """
        return output.get("nh_loss", 0) * self.weight
# Lookup table from a loss name to the class implementing that loss term.
LOSS_REGISTER: Dict[str, type] = {"nh_penalty": NHLoss}