⬅ common/metrics/loss.py source

1 import torch
2  
  • E501 Line too long (80 > 79 characters)
3 from gds.common.metrics.metric import ElementwiseMetric, Metric, MultiTaskMetric
4 from gds.common.utils import maximum
5  
6  
class Loss(Metric):
    """Aggregate metric backed by an arbitrary loss callable.

    ``loss_fn`` is invoked once on the whole (prediction, target) pair and
    its return value is used directly as the metric value.
    """

    def __init__(self, loss_fn, name=None):
        # Store the callable before handing control to the base class.
        self.loss_fn = loss_fn
        super().__init__(name='loss' if name is None else name)

    def _compute(self, y_pred, y_true):
        """
        Apply the wrapped loss callable to predictions and targets.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
        Output:
            - loss: the value returned by ``self.loss_fn(y_pred, y_true)``,
              passed through unchanged
        """
        loss_value = self.loss_fn(y_pred, y_true)
        return loss_value

    def worst(self, metrics):
        """
        Return the worst-case value among ``metrics``.
        For a loss a larger value is worse, hence the maximum is taken.
        Args:
            - metrics (Tensor, numpy array, or list): Metrics
        Output:
            - worst_metric (float): Largest (worst) metric value
        """
        return maximum(metrics)
35  
36  
class ElementwiseLoss(ElementwiseMetric):
    """Element-wise metric backed by a supported ``torch.nn`` loss module.

    Only ``torch.nn.BCEWithLogitsLoss`` and ``torch.nn.CrossEntropyLoss``
    are supported. NOTE(review): ``loss_fn`` presumably has to be built
    with ``reduction='none'`` so that one value per example comes back --
    confirm against callers.
    """

    def __init__(self, loss_fn, name=None):
        self.loss_fn = loss_fn
        if name is None:
            name = 'loss'
        super().__init__(name=name)

    def _compute_element_wise(self, y_pred, y_true):
        """
        Helper for computing element-wise metric, implemented for each metric
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
        Output:
            - element_wise_metrics (Tensor): output of ``self.loss_fn``
              (for BCEWithLogitsLoss, with a trailing size-1 dim removed)
        Raises:
            - NotImplementedError: if ``self.loss_fn`` is neither a
              ``torch.nn.BCEWithLogitsLoss`` nor a
              ``torch.nn.CrossEntropyLoss``
        """
        if isinstance(self.loss_fn, torch.nn.BCEWithLogitsLoss):
            # BCE needs floating-point inputs; squeezing the last dim turns
            # (batch_size, 1) results into (batch_size,) results.
            return self.loss_fn(
                y_pred.float(), y_true.float()).squeeze(dim=-1)
        if isinstance(self.loss_fn, torch.nn.CrossEntropyLoss):
            return self.loss_fn(y_pred, y_true)
        # Name the offending type so misconfiguration is easy to diagnose.
        raise NotImplementedError(
            'ElementwiseLoss does not support loss_fn of type {}; expected '
            'torch.nn.BCEWithLogitsLoss or torch.nn.CrossEntropyLoss'.format(
                type(self.loss_fn).__name__))

    def worst(self, metrics):
        """
        Given a list/numpy array/Tensor of metrics, computes the worst-case
        metric (the maximum, since a larger loss is worse).
        Args:
            - metrics (Tensor, numpy array, or list): Metrics
        Output:
            - worst_metric (float): Worst-case metric
        """
        return maximum(metrics)
74  
75  
class MultiTaskLoss(MultiTaskMetric):
    """Multi-task metric that applies an element-wise loss to flattened
    predictions and targets, coercing dtypes for known torch losses."""

    def __init__(self, loss_fn, name=None):
        # loss_fn is expected to be an element-wise loss.
        self.loss_fn = loss_fn
        super().__init__(name='loss' if name is None else name)

    def _compute_flattened(self, flattened_y_pred, flattened_y_true):
        """
        Evaluate the wrapped loss on flattened predictions and targets.
        Inputs are cast to float for BCEWithLogitsLoss and targets to long
        for CrossEntropyLoss; any other loss receives them unchanged.
        Args:
            - flattened_y_pred (Tensor): Flattened predictions
            - flattened_y_true (Tensor): Flattened targets
        Output:
            - the value returned by ``self.loss_fn`` on the (possibly
              cast) inputs
        """
        preds, targets = flattened_y_pred, flattened_y_true
        if isinstance(self.loss_fn, torch.nn.BCEWithLogitsLoss):
            preds, targets = preds.float(), targets.float()
        elif isinstance(self.loss_fn, torch.nn.CrossEntropyLoss):
            targets = targets.long()
        return self.loss_fn(preds, targets)

    def worst(self, metrics):
        """
        Return the worst-case value among ``metrics``.
        A larger loss is worse, so the maximum is returned.
        Args:
            - metrics (Tensor, numpy array, or list): Metrics
        Output:
            - worst_metric (float): Largest (worst) metric value
        """
        return maximum(metrics)