from abc import ABC, abstractmethod
import math
from typing import Dict
import torch
from torch.nn import MSELoss as MSELoss_
from torch.nn import L1Loss
class WeightedLoss(ABC):
    """Base class for losses that sum per-key terms with scalar weights.

    Subclasses implement :meth:`loss_fn` for a single key. Calling an
    instance evaluates ``loss_fn`` once for every key passed to the
    constructor and returns the weighted sum of the resulting terms.
    """

    def __init__(self, **weights: float) -> None:
        """Store the per-key weights.

        Args:
            **weights: Maps a key that is looked up in both ``output`` and
                ``target`` to the factor its loss term is multiplied by.
        """
        self.weights = weights

    @abstractmethod
    def loss_fn(self, output: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor], k: str) -> torch.Tensor:
        """Return the unweighted loss term for key ``k``.

        Args:
            output: Model outputs, indexed by key.
            target: Reference values, indexed by key.
            k: The key whose loss term is requested.
        """
        ...

    def __call__(self, output: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the weighted sum of ``loss_fn`` over all configured keys.

        When ``output[k]`` has exactly one more dimension than ``target[k]``,
        the target is unsqueezed on its last axis and expanded to the shape
        of ``output[k]`` for the duration of the ``loss_fn`` call.

        Args:
            output: Model outputs, indexed by key.
            target: Reference values, indexed by key. Left unchanged on
                return: any temporarily expanded entry is restored to the
                original tensor object.

        Returns:
            The accumulated weighted loss. If no weights were configured the
            result is the plain integer ``0``.
        """
        loss = 0
        for k, v in self.weights.items():
            if output[k].dim() == target[k].dim() + 1:
                original = target[k]
                target[k] = original.unsqueeze(-1).expand_as(output[k])
                try:
                    loss_term = self.loss_fn(output, target, k)
                finally:
                    # Put back the caller's own tensor (not a view derived
                    # from the expanded one), even if loss_fn raised.
                    target[k] = original
            else:
                loss_term = self.loss_fn(output, target, k)
            loss = loss + v * loss_term
        return loss
class MSELoss(WeightedLoss):
    """Weighted loss whose per-key term is ``torch.nn.MSELoss``."""

    def __init__(self, **weights: float) -> None:
        """Store the weights and build the underlying criterion.

        Args:
            **weights: Maps a key to the factor its loss term is
                multiplied by.
        """
        super().__init__(**weights)
        # Criterion is constructed with torch's default arguments.
        self.mseloss = MSELoss_()

    def loss_fn(self, output: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor], k: str) -> torch.Tensor:
        """Return the MSE criterion applied to ``output[k]`` and ``target[k]``.

        Args:
            output: Model outputs, indexed by key.
            target: Reference values, indexed by key.
            k: The key to evaluate.
        """
        return self.mseloss(output[k], target[k])
class RMSELoss(WeightedLoss):
    """Weighted loss whose per-key term is the square root of the MSE."""

    def __init__(self, **weights: float) -> None:
        """Store the weights and build the underlying MSE criterion.

        Args:
            **weights: Maps a key to the factor its loss term is
                multiplied by.
        """
        super().__init__(**weights)
        # Criterion is constructed with torch's default arguments.
        self.mseloss = MSELoss_()

    def loss_fn(self, output: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor], k: str) -> torch.Tensor:
        """Return ``sqrt(MSE(output[k], target[k]))``.

        Args:
            output: Model outputs, indexed by key.
            target: Reference values, indexed by key.
            k: The key to evaluate.
        """
        # NOTE(review): no epsilon is added before the square root, so the
        # gradient is presumably ill-defined when the MSE is exactly zero --
        # confirm whether that case can occur in training.
        return torch.sqrt(self.mseloss(output[k], target[k]))
class MAELoss(WeightedLoss):
    """Weighted loss whose per-key term is ``torch.nn.L1Loss``."""

    def __init__(self, **weights: float) -> None:
        """Store the weights and build the underlying criterion.

        Args:
            **weights: Maps a key to the factor its loss term is
                multiplied by.
        """
        super().__init__(**weights)
        # Criterion is constructed with torch's default arguments.
        self.maeloss = L1Loss()

    def loss_fn(self, output: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor], k: str) -> torch.Tensor:
        """Return the L1 criterion applied to ``output[k]`` and ``target[k]``.

        Args:
            output: Model outputs, indexed by key.
            target: Reference values, indexed by key.
            k: The key to evaluate.
        """
        return self.maeloss(output[k], target[k])
class NLLLoss(WeightedLoss):
    """Weighted loss built from a predicted mean and a predicted variance.

    For key ``k`` the mean is read from ``output[k]`` and the variance from
    ``output[k + "_var"]``.
    """

    def __init__(self, eps: float = 1e-6, **weights: float) -> None:
        """Store the variance floor and the per-key weights.

        Args:
            eps: Lower bound the variance is clamped to.
            **weights: Maps a key to the factor its loss term is
                multiplied by.
        """
        super().__init__(**weights)
        self.eps = eps

    def loss_fn(self, output: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor], k: str) -> torch.Tensor:
        """Return ``0.5 * mean(log(var) + (output[k] - target[k])**2 / var)``.

        Args:
            output: Model outputs; must contain ``k`` and ``k + "_var"``.
            target: Reference values, indexed by key.
            k: The key to evaluate.

        Raises:
            KeyError: If ``k + "_var"`` is missing from ``output``.
        """
        # The variance is clamped into [eps, 1] once and reused for both the
        # log term and the division.
        # NOTE(review): the upper bound of 1 is kept exactly as in the
        # original code; confirm it is intentional.
        var = torch.clamp(output[k + "_var"], self.eps, 1)
        squared_error = (output[k] - target[k]) ** 2
        return 0.5 * torch.mean(torch.log(var) + squared_error / var)
class NLLLossVarOnly(WeightedLoss):
    """Variant of the mean/variance loss with a gradient-free squared error.

    For key ``k`` the mean is read from ``output[k]`` and the variance from
    ``output[k + "_var"]``. The squared error is evaluated under
    ``torch.no_grad()``, so it enters the loss without gradient tracking.
    """

    def __init__(self, eps: float = 1e-6, **weights: float) -> None:
        """Store the variance offset and the per-key weights.

        Args:
            eps: Constant added to the variance before the log and the
                division.
            **weights: Maps a key to the factor its loss term is
                multiplied by.
        """
        super().__init__(**weights)
        self.eps = eps

    def loss_fn(self, output: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor], k: str) -> torch.Tensor:
        """Return ``0.5 * mean(log(var + eps) + sq_err / (var + eps))``.

        Args:
            output: Model outputs; must contain ``k`` and ``k + "_var"``.
            target: Reference values, indexed by key.
            k: The key to evaluate.

        Raises:
            KeyError: If ``k + "_var"`` is missing from ``output``.
        """
        # Computed without gradient tracking: only the variance term below
        # is part of the autograd graph.
        with torch.no_grad():
            mse = (output[k] - target[k]) ** 2
        var = output[k + "_var"] + self.eps
        return 0.5 * torch.mean(torch.log(var) + mse / var)
class CRPSLoss(WeightedLoss):
    """Weighted loss using the closed-form CRPS expression for a Gaussian.

    For key ``k`` the mean is read from ``output[k]``; the spread is read
    from ``output[k + "_std"]`` if present, otherwise derived from
    ``output[k + "_var"]``.
    """

    def __init__(self, eps: float = 1e-6, **weights: float) -> None:
        """Store the weights and set up the standard-normal helpers.

        Args:
            eps: Constant added to the standard deviation (or to the
                variance before taking its square root).
            **weights: Maps a key to the factor its loss term is
                multiplied by.
        """
        from torch.distributions.normal import Normal

        super().__init__(**weights)
        self.normal = Normal(0, 1)
        self.eps = eps
        # Standard-normal density, obtained by exponentiating log_prob.
        self.phi = lambda x: self.normal.log_prob(x).exp()
        # Standard-normal cumulative distribution function.
        self.Phi = self.normal.cdf
        self.sqrt_pi = math.sqrt(math.pi)

    def loss_fn(self, output: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor], k: str) -> torch.Tensor:
        """Return ``mean(std * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi)))``.

        ``z`` is ``(output[k] - target[k]) / std``.

        Args:
            output: Model outputs; must contain ``k`` and either
                ``k + "_std"`` or ``k + "_var"``. ``"_std"`` takes
                precedence when both are present.
            target: Reference values, indexed by key.
            k: The key to evaluate.

        Raises:
            KeyError: If neither ``k + "_std"`` nor ``k + "_var"`` is in
                ``output``.
        """
        if k + "_std" in output:
            std = output[k + "_std"] + self.eps
        elif k + "_var" in output:
            std = torch.sqrt(output[k + "_var"] + self.eps)
        else:
            # Fail with a clear message instead of an unbound local below.
            raise KeyError(
                f"CRPSLoss needs '{k}_std' or '{k}_var' in output"
            )
        dev = (output[k] - target[k]) / std
        return torch.mean(std * (dev * (2 * self.Phi(dev) - 1) + 2 * self.phi(dev) - 1 / self.sqrt_pi))
class L2Penalty:
    """Scale a penalty term that is read from the model output."""

    def __init__(self, weight: float) -> None:
        """Store the scaling factor.

        Args:
            weight: Factor the ``"l2_penalty"`` entry is multiplied by.
        """
        self.weight = weight

    def __call__(self, output: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return ``output["l2_penalty"] * weight``.

        Args:
            output: Model outputs; the ``"l2_penalty"`` entry is used when
                present.
            target: Not read by this method; accepted so the call signature
                matches the other losses in this module.

        Returns:
            The scaled penalty. When ``"l2_penalty"`` is absent the value
            ``0`` is used in its place, so the result is ``0 * weight``.
        """
        return output.get("l2_penalty", 0) * self.weight
# Lookup table from a loss's short string name to the class implementing it.
LOSS_REGISTER = {
    "mae": MAELoss,
    "mse": MSELoss,
    "rmse": RMSELoss,
    "nll": NLLLoss,
    "nll_var_only": NLLLossVarOnly,
    "l2_penalty": L2Penalty,
    "crps": CRPSLoss,
}