1 import torch
2
3 from gds.common.utils import avg_over_groups, get_counts, numel
4
5
class Metric:
    """
    Parent class for metrics.

    Subclasses implement ``_compute`` (and ``worst``); the public wrappers
    ``compute`` and ``compute_group_wise`` handle empty inputs, per-group
    evaluation and packaging of results into dictionaries.
    """

    def __init__(self, name):
        self._name = name

    def _compute(self, y_pred, y_true):
        """
        Helper function for computing the metric.
        Subclasses should implement this.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
        Output:
            - metric (0-dim tensor): metric
        Raises:
            - NotImplementedError: always, in this base class
        """
        # Must *raise*: returning the exception class would hand a
        # non-tensor back to compute(), which then fails on `.item()`.
        raise NotImplementedError

    def worst(self, metrics):
        """
        Given a list/numpy array/Tensor of metrics, computes the
        worst-case metric.
        Args:
            - metrics (Tensor, numpy array, or list): Metrics
        Output:
            - worst_metric (0-dim tensor): Worst-case metric
        """
        raise NotImplementedError

    @property
    def name(self):
        """
        Metric name.
        Used to name the key in the results dictionaries returned by the
        metric.
        """
        return self._name

    @property
    def agg_metric_field(self):
        """
        The name of the key in the results dictionary returned by
        Metric.compute(). This should correspond to the aggregate metric
        computed on all of y_pred and y_true, in contrast to a group-wise
        evaluation.
        """
        return f'{self.name}_all'

    def group_metric_field(self, group_idx):
        """
        The name of the keys corresponding to individual group evaluations
        in the results dictionary returned by Metric.compute_group_wise().
        """
        return f'{self.name}_group:{group_idx}'

    @property
    def worst_group_metric_field(self):
        """
        The name of the keys corresponding to the worst-group metric
        in the results dictionary returned by Metric.compute_group_wise().
        """
        return f'{self.name}_wg'

    def group_count_field(self, group_idx):
        """
        The name of the keys corresponding to each group's count
        in the results dictionary returned by Metric.compute_group_wise().
        """
        return f'count_group:{group_idx}'

    def compute(self, y_pred, y_true, return_dict=True):
        """
        Computes metric. This is a wrapper around _compute.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - return_dict (bool): Whether to return the output as a
              dictionary or a tensor
        Output (return_dict=False):
            - metric (0-dim tensor): metric. If the inputs are empty,
              returns tensor(0.)
        Output (return_dict=True):
            - results (dict): Dictionary of results, mapping
              metric.agg_metric_field to avg_metric
        """
        if numel(y_true) == 0:
            # Empty input: report 0 instead of calling _compute on nothing.
            agg_metric = torch.tensor(0., device=y_true.device)
        else:
            agg_metric = self._compute(y_pred, y_true)
        if return_dict:
            return {self.agg_metric_field: agg_metric.item()}
        return agg_metric

    def compute_group_wise(self, y_pred, y_true, g, n_groups,
                           return_dict=True):
        """
        Computes metrics for each group. This is a wrapper around _compute.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - g (Tensor): groups
            - n_groups (int): number of groups
            - return_dict (bool): Whether to return the output as a
              dictionary or a tensor
        Output (return_dict=False):
            - group_metrics (Tensor): tensor of size (n_groups, ) including
              the average metric for each group
            - group_counts (Tensor): tensor of size (n_groups, ) including
              the group count
            - worst_group_metric (0-dim tensor): worst-group metric
            - For empty inputs/groups, corresponding metrics are tensor(0.)
        Output (return_dict=True):
            - results (dict): Dictionary of results, with one metric key
              and one count key per group followed by the worst-group key
        """
        group_metrics, group_counts, worst_group_metric = \
            self._compute_group_wise(y_pred, y_true, g, n_groups)
        if not return_dict:
            return group_metrics, group_counts, worst_group_metric
        results = {}
        for group_idx in range(n_groups):
            results[self.group_metric_field(group_idx)] = (
                group_metrics[group_idx].item())
            results[self.group_count_field(group_idx)] = (
                group_counts[group_idx].item())
        results[self.worst_group_metric_field] = worst_group_metric.item()
        return results

    def _compute_group_wise(self, y_pred, y_true, g, n_groups):
        """
        Default group-wise evaluation: calls _compute separately on the
        examples of each group.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - g (Tensor): groups
            - n_groups (int): number of groups
        Output:
            - (group_metrics, group_counts, worst_group_metric), as
              described in compute_group_wise()
        """
        group_metrics = []
        group_counts = get_counts(g, n_groups)
        for group_idx in range(n_groups):
            if group_counts[group_idx] == 0:
                # Placeholder for empty groups; excluded from worst() below.
                group_metrics.append(torch.tensor(0., device=g.device))
            else:
                in_group = g == group_idx
                group_metrics.append(
                    self._compute(y_pred[in_group], y_true[in_group]))

        group_metrics = torch.stack(group_metrics)
        # Only non-empty groups participate in the worst-group metric.
        worst_group_metric = self.worst(group_metrics[group_counts > 0])

        return group_metrics, group_counts, worst_group_metric
143
144
class ElementwiseMetric(Metric):
    """
    Metric whose value is the average of per-example (element-wise) values.

    Subclasses implement ``_compute_element_wise`` (one value per example,
    i.e. per entry along the first dimension of y_pred) and ``worst``.
    The aggregate metric is the mean of the element-wise values, and the
    group-wise metrics average them within each group.
    """

    def _compute_element_wise(self, y_pred, y_true):
        """
        Helper for computing element-wise metric, implemented for each
        metric.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
        Output:
            - element_wise_metrics (Tensor): tensor of size (batch_size, )
        """
        raise NotImplementedError

    def _compute(self, y_pred, y_true):
        """
        Helper function for computing the metric.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
        Output:
            - avg_metric (0-dim tensor): average of element-wise metrics
        """
        element_wise_metrics = self._compute_element_wise(y_pred, y_true)
        return element_wise_metrics.mean()

    def _compute_group_wise(self, y_pred, y_true, g, n_groups):
        """
        Group-wise evaluation computed from a single element-wise pass:
        the per-example values are averaged within each group.
        Output:
            - (group_metrics, group_counts, worst_group_metric)
        """
        element_wise_metrics = self._compute_element_wise(y_pred, y_true)
        group_metrics, group_counts = avg_over_groups(
            element_wise_metrics, g, n_groups)
        # Only non-empty groups participate in the worst-group metric.
        worst_group_metric = self.worst(group_metrics[group_counts > 0])
        return group_metrics, group_counts, worst_group_metric

    @property
    def agg_metric_field(self):
        """
        The name of the key in the results dictionary returned by
        Metric.compute().
        """
        return f'{self.name}_avg'

    def compute_element_wise(self, y_pred, y_true, return_dict=True):
        """
        Computes element-wise metric
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - return_dict (bool): Whether to return the output as a
              dictionary or a tensor
        Output (return_dict=False):
            - element_wise_metrics (Tensor): tensor of size (batch_size, )
        Output (return_dict=True):
            - results (dict): Dictionary of results, mapping metric.name
              to element_wise_metrics
        Raises:
            - AssertionError: if the subclass does not return a 1-D tensor
              with one entry per row of y_pred
        """
        element_wise_metrics = self._compute_element_wise(y_pred, y_true)
        batch_size = y_pred.size()[0]

        assert (element_wise_metrics.dim() == 1
                and element_wise_metrics.numel() == batch_size), \
            'element-wise metrics must be a 1-D tensor of size batch_size'
        if return_dict:
            return {self.name: element_wise_metrics}
        return element_wise_metrics

    def compute_flattened(self, y_pred, y_true, return_dict=True):
        """
        Computes element-wise metrics together with the batch index of
        each one (the counterpart of MultiTaskMetric.compute_flattened).
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - return_dict (bool): Whether to return a dict or a tuple
        Output (return_dict=False):
            - flattened_metrics (Tensor): tensor of size (batch_size, )
            - index (Tensor): arange(batch_size), one index per metric
        Output (return_dict=True):
            - results (dict): mapping metric.name to flattened_metrics and
              'index' to index
        """
        flattened_metrics = self.compute_element_wise(
            y_pred, y_true, return_dict=False)
        # One index per element-wise metric (i.e. per batch row). Sizing it
        # from y_true.numel() would over-count whenever y_true has trailing
        # dimensions such as (batch_size, 1), misaligning index and metrics.
        index = torch.arange(flattened_metrics.numel())
        if return_dict:
            return {self.name: flattened_metrics, 'index': index}
        return flattened_metrics, index
225
226
class MultiTaskMetric(Metric):
    """
    Metric for targets where missing labels are encoded as NaN.

    Only the labeled entries (``~isnan(y_true)``) of y_pred / y_true are
    selected and handed to ``_compute_flattened``, which subclasses
    implement. The aggregate metric is the mean of the resulting values
    (tensor(0.) when there are none).
    """

    def _compute_flattened(self, flattened_y_pred, flattened_y_true):
        """
        Scores the labeled entries; subclasses must implement this.
        Args:
            - flattened_y_pred (Tensor): y_pred at the labeled positions
            - flattened_y_true (Tensor): y_true at the labeled positions
        Output:
            - flattened_metrics (Tensor): metric values for those entries
        """
        raise NotImplementedError

    def _compute(self, y_pred, y_true):
        """
        Mean of the flattened metric values, or tensor(0.) on y_true's
        device when no entry is labeled.
        """
        scores, _unused_rows = self.compute_flattened(
            y_pred, y_true, return_dict=False)
        if scores.numel() != 0:
            return scores.mean()
        return torch.tensor(0., device=y_true.device)

    def _compute_group_wise(self, y_pred, y_true, g, n_groups):
        """
        Group-wise evaluation: every labeled entry is assigned the group
        of the row it came from, then values are averaged per group.
        Output:
            - (group_metrics, group_counts, worst_group_metric)
        """
        scores, rows = self.compute_flattened(
            y_pred, y_true, return_dict=False)
        entry_groups = g[rows]
        group_metrics, group_counts = avg_over_groups(
            scores, entry_groups, n_groups)
        # Empty groups are left out of the worst-group computation.
        nonempty = group_counts > 0
        worst_group_metric = self.worst(group_metrics[nonempty])
        return group_metrics, group_counts, worst_group_metric

    def compute_flattened(self, y_pred, y_true, return_dict=True):
        """
        Computes metric values for the labeled (non-NaN) entries only.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets, NaN where unlabeled
            - return_dict (bool): Whether to return a dict or a tuple
        Output (return_dict=False):
            - flattened_metrics (Tensor): values from _compute_flattened
            - index (Tensor): first-dimension (row) index of each labeled
              entry
        Output (return_dict=True):
            - results (dict): mapping metric.name to flattened_metrics and
              'index' to index
        """
        labeled_mask = ~torch.isnan(y_true)
        # torch.where on a mask returns one index tensor per dimension;
        # the first one holds the row of each labeled entry.
        row_index = torch.where(labeled_mask)[0]

        scores = self._compute_flattened(
            y_pred[labeled_mask], y_true[labeled_mask])
        if not return_dict:
            return scores, row_index
        return {self.name: scores, 'index': row_index}