import torch

from gds.common.utils import avg_over_groups, get_counts, numel


class Metric:
    """
    Parent class for metrics.
    """
    def __init__(self, name):
        self._name = name

    def _compute(self, y_pred, y_true):
        """
        Helper function for computing the metric.
        Subclasses should implement this.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
        Output:
            - metric (0-dim tensor): metric
        """
        raise NotImplementedError

    def worst(self, metrics):
        """
        Given a list/numpy array/Tensor of metrics, computes the worst-case metric.
        Args:
            - metrics (Tensor, numpy array, or list): Metrics
        Output:
            - worst_metric (0-dim tensor): Worst-case metric
        """
        raise NotImplementedError

    @property
    def name(self):
        """
        Metric name.
        Used to name the key in the results dictionaries returned by the metric.
        """
        return self._name

    @property
    def agg_metric_field(self):
        """
        The name of the key in the results dictionary returned by Metric.compute().
        This should correspond to the aggregate metric computed on all of y_pred and y_true,
        in contrast to a group-wise evaluation.
        """
        return f'{self.name}_all'

    def group_metric_field(self, group_idx):
        """
        The name of the keys corresponding to individual group evaluations
        in the results dictionary returned by Metric.compute_group_wise().
        """
        return f'{self.name}_group:{group_idx}'

    @property
    def worst_group_metric_field(self):
        """
        The name of the key corresponding to the worst-group metric
        in the results dictionary returned by Metric.compute_group_wise().
        """
        return f'{self.name}_wg'

    def group_count_field(self, group_idx):
        """
        The name of the keys corresponding to each group's count
        in the results dictionary returned by Metric.compute_group_wise().
        """
        return f'count_group:{group_idx}'

    def compute(self, y_pred, y_true, return_dict=True):
        """
        Computes the metric. This is a wrapper around _compute.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - return_dict (bool): Whether to return the output as a dictionary or a tensor
        Output (return_dict=False):
            - metric (0-dim tensor): metric. If the inputs are empty, returns tensor(0.)
        Output (return_dict=True):
            - results (dict): Dictionary of results, mapping metric.agg_metric_field to avg_metric
        """
        if numel(y_true) == 0:
            agg_metric = torch.tensor(0., device=y_true.device)
        else:
            agg_metric = self._compute(y_pred, y_true)
        if return_dict:
            results = {
                self.agg_metric_field: agg_metric.item()
            }
            return results
        else:
            return agg_metric

    def compute_group_wise(self, y_pred, y_true, g, n_groups, return_dict=True):
        """
        Computes the metric for each group. This is a wrapper around _compute_group_wise.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - g (Tensor): groups
            - n_groups (int): number of groups
            - return_dict (bool): Whether to return the output as a dictionary or a tensor
        Output (return_dict=False):
            - group_metrics (Tensor): tensor of size (n_groups, ) including the average metric for each group
            - group_counts (Tensor): tensor of size (n_groups, ) including the group count
            - worst_group_metric (0-dim tensor): worst-group metric
            - For empty inputs/groups, corresponding metrics are tensor(0.)
        Output (return_dict=True):
            - results (dict): Dictionary of results
        """
        group_metrics, group_counts, worst_group_metric = self._compute_group_wise(y_pred, y_true, g, n_groups)
        if return_dict:
            results = {}
            for group_idx in range(n_groups):
                results[self.group_metric_field(group_idx)] = group_metrics[group_idx].item()
                results[self.group_count_field(group_idx)] = group_counts[group_idx].item()
            results[self.worst_group_metric_field] = worst_group_metric.item()
            return results
        else:
            return group_metrics, group_counts, worst_group_metric

    def _compute_group_wise(self, y_pred, y_true, g, n_groups):
        group_metrics = []
        group_counts = get_counts(g, n_groups)
        for group_idx in range(n_groups):
            if group_counts[group_idx] == 0:
                group_metrics.append(torch.tensor(0., device=g.device))
            else:
                group_metrics.append(
                    self._compute(
                        y_pred[g == group_idx],
                        y_true[g == group_idx]))
        group_metrics = torch.stack(group_metrics)
        worst_group_metric = self.worst(group_metrics[group_counts > 0])
        return group_metrics, group_counts, worst_group_metric
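

# --- Illustrative sketch (not part of the original module) ---
# A concrete metric only needs to override `_compute` and `worst`. The class
# name and keyword defaults below are hypothetical, shown only to demonstrate
# how the Metric interface is meant to be used.
class _ExampleMSE(Metric):
    def __init__(self, name='mse'):
        super().__init__(name=name)

    def _compute(self, y_pred, y_true):
        # Mean squared error over all elements in the batch.
        return ((y_pred - y_true) ** 2).mean()

    def worst(self, metrics):
        # Higher MSE is worse, so the worst case is the maximum.
        return torch.max(metrics)

# Usage (sketch):
#   metric = _ExampleMSE()
#   metric.compute(y_pred, y_true)                           # {'mse_all': ...}
#   metric.compute_group_wise(y_pred, y_true, g, n_groups)   # per-group keys + 'mse_wg'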


class ElementwiseMetric(Metric):
    """
    Metrics computed element-wise (one value per example) and then averaged
    over examples, and over groups for group-wise evaluation.
    """
    def _compute_element_wise(self, y_pred, y_true):
        """
        Helper for computing the element-wise metric, implemented for each metric.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
        Output:
            - element_wise_metrics (Tensor): tensor of size (batch_size, )
        """
        raise NotImplementedError

    def worst(self, metrics):
        """
        Given a list/numpy array/Tensor of metrics, computes the worst-case metric.
        Args:
            - metrics (Tensor, numpy array, or list): Metrics
        Output:
            - worst_metric (0-dim tensor): Worst-case metric
        """
        raise NotImplementedError

    def _compute(self, y_pred, y_true):
        """
        Helper function for computing the metric.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
        Output:
            - avg_metric (0-dim tensor): average of element-wise metrics
        """
        element_wise_metrics = self._compute_element_wise(y_pred, y_true)
        avg_metric = element_wise_metrics.mean()
        return avg_metric

    def _compute_group_wise(self, y_pred, y_true, g, n_groups):
        element_wise_metrics = self._compute_element_wise(y_pred, y_true)
        group_metrics, group_counts = avg_over_groups(element_wise_metrics, g, n_groups)
        worst_group_metric = self.worst(group_metrics[group_counts > 0])
        return group_metrics, group_counts, worst_group_metric

    @property
    def agg_metric_field(self):
        """
        The name of the key in the results dictionary returned by Metric.compute().
        """
        return f'{self.name}_avg'

    def compute_element_wise(self, y_pred, y_true, return_dict=True):
        """
        Computes the element-wise metric.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - return_dict (bool): Whether to return the output as a dictionary or a tensor
        Output (return_dict=False):
            - element_wise_metrics (Tensor): tensor of size (batch_size, )
        Output (return_dict=True):
            - results (dict): Dictionary of results, mapping metric.name to element_wise_metrics
        """
        element_wise_metrics = self._compute_element_wise(y_pred, y_true)
        batch_size = y_pred.size()[0]
        assert element_wise_metrics.dim() == 1 and element_wise_metrics.numel() == batch_size
        if return_dict:
            return {self.name: element_wise_metrics}
        else:
            return element_wise_metrics

    def compute_flattened(self, y_pred, y_true, return_dict=True):
        # For element-wise metrics, the flattened metrics are just the
        # element-wise metrics, indexed by position within the batch.
        flattened_metrics = self.compute_element_wise(y_pred, y_true, return_dict=False)
        index = torch.arange(y_true.numel())
        if return_dict:
            return {self.name: flattened_metrics, 'index': index}
        else:
            return flattened_metrics, index
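

# --- Illustrative sketch (not part of the original module) ---
# An element-wise metric only overrides `_compute_element_wise` and `worst`;
# the per-example values are averaged over the batch (and over groups) by the
# parent class. The class name below is hypothetical; it assumes `y_pred`
# holds per-class scores and `y_true` holds integer labels.
class _ExampleAccuracy(ElementwiseMetric):
    def __init__(self, name='acc'):
        super().__init__(name=name)

    def _compute_element_wise(self, y_pred, y_true):
        # 1.0 where the argmax prediction matches the label, else 0.0
        return (y_pred.argmax(dim=-1) == y_true).float()

    def worst(self, metrics):
        # Lower accuracy is worse, so the worst case is the minimum.
        return torch.min(metrics)

# Usage (sketch):
#   acc = _ExampleAccuracy()
#   acc.compute(logits, labels)                # {'acc_avg': ...}
#   acc.compute_element_wise(logits, labels)   # {'acc': tensor of shape (batch_size,)}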


class MultiTaskMetric(Metric):
    def _compute_flattened(self, flattened_y_pred, flattened_y_true):
        raise NotImplementedError

    def _compute(self, y_pred, y_true):
        flattened_metrics, _ = self.compute_flattened(y_pred, y_true, return_dict=False)
        if flattened_metrics.numel() == 0:
            return torch.tensor(0., device=y_true.device)
        else:
            return flattened_metrics.mean()

    def _compute_group_wise(self, y_pred, y_true, g, n_groups):
        flattened_metrics, indices = self.compute_flattened(y_pred, y_true, return_dict=False)
        flattened_g = g[indices]
        group_metrics, group_counts = avg_over_groups(flattened_metrics, flattened_g, n_groups)
        worst_group_metric = self.worst(group_metrics[group_counts > 0])
        return group_metrics, group_counts, worst_group_metric

    def compute_flattened(self, y_pred, y_true, return_dict=True):
        # Only labeled entries (non-NaN targets) contribute to the metric;
        # batch_idx records which example each flattened entry came from.
        is_labeled = ~torch.isnan(y_true)
        batch_idx = torch.where(is_labeled)[0]
        flattened_y_pred = y_pred[is_labeled]
        flattened_y_true = y_true[is_labeled]
        flattened_metrics = self._compute_flattened(flattened_y_pred, flattened_y_true)
        if return_dict:
            return {self.name: flattened_metrics, 'index': batch_idx}
        else:
            return flattened_metrics, batch_idx
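

# --- Illustrative sketch (not part of the original module) ---
# MultiTaskMetric flattens y_pred/y_true and drops entries whose target is NaN,
# which is useful when each example is labeled for only a subset of tasks.
# The class name and the 0-threshold on logits below are hypothetical.
class _ExampleMultiTaskAccuracy(MultiTaskMetric):
    def __init__(self, name='acc'):
        super().__init__(name=name)

    def _compute_flattened(self, flattened_y_pred, flattened_y_true):
        # Binary prediction from a logit, compared against a 0/1 target.
        preds = (flattened_y_pred > 0).float()
        return (preds == flattened_y_true).float()

    def worst(self, metrics):
        # Lower accuracy is worse, so the worst case is the minimum.
        return torch.min(metrics)

# Usage (sketch): NaN marks an unlabeled task for that example.
#   y_true = torch.tensor([[1., float('nan')], [0., 1.]])
#   y_pred = torch.tensor([[2.3, -0.5], [-1.0, 0.7]])
#   _ExampleMultiTaskAccuracy().compute(y_pred, y_true)   # {'acc_all': 1.0}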