# common/metrics/all_metrics.py

import numpy as np
import torch
import torch.nn.functional as F
from torchvision.ops.boxes import box_iou
from torchvision.models.detection._utils import Matcher
from gds.common.metrics.metric import (
    Metric, ElementwiseMetric, MultiTaskMetric
)
from gds.common.metrics.loss import ElementwiseLoss
from gds.common.utils import minimum, get_counts
import sklearn.metrics
from scipy.stats import pearsonr


def binary_logits_to_score(logits):
    """Convert binary logits to a positive-class score."""
    assert logits.dim() in (1, 2)
    if logits.dim() == 2:  # (batch_size, 2) multi-class logits
        assert logits.size(1) == 2, "Only binary classification"
        score = F.softmax(logits, dim=1)[:, 1]
    else:  # (batch_size,) single-logit scores are used as-is
        score = logits
    return score
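
# Illustrative example (hypothetical values):
#   >>> logits = torch.tensor([[2.0, -1.0], [0.5, 1.5]])
#   >>> binary_logits_to_score(logits)  # softmax probability of class 1
#   tensor([0.0474, 0.7311])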


def multiclass_logits_to_pred(logits):
    """
    Takes multi-class logits of size (batch_size, ..., n_classes) and
    returns predictions by taking an argmax at the last dimension.
    """
    assert logits.dim() > 1
    return logits.argmax(-1)
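
# Illustrative example (hypothetical values):
#   >>> logits = torch.tensor([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]])
#   >>> multiclass_logits_to_pred(logits)
#   tensor([1, 0])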


def binary_logits_to_pred(logits):
    """Threshold single-logit outputs at 0 to get binary predictions."""
    return (logits > 0).long()


class Accuracy(ElementwiseMetric):
    def __init__(self, prediction_fn=None, name=None):
        self.prediction_fn = prediction_fn
        if name is None:
            name = 'acc'
        super().__init__(name=name)

    def _compute_element_wise(self, y_pred, y_true):
        if self.prediction_fn is not None:
            y_pred = self.prediction_fn(y_pred)
        return (y_pred == y_true).float()

    def worst(self, metrics):
        return minimum(metrics)
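
# Sketch of typical usage (hypothetical tensors; assumes the Metric base
# class exposes a `compute(y_pred, y_true)` entry point that aggregates
# the element-wise values under a key derived from the metric name):
#   >>> metric = Accuracy(prediction_fn=multiclass_logits_to_pred)
#   >>> logits = torch.tensor([[0.1, 2.0], [3.0, 0.0]])
#   >>> y_true = torch.tensor([1, 1])
#   >>> metric.compute(logits, y_true)  # -> e.g. {'acc_avg': 0.5}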


class MultiTaskAccuracy(MultiTaskMetric):
    def __init__(self, prediction_fn=None, name=None):
        # prediction_fn should work on flattened inputs
        self.prediction_fn = prediction_fn
        if name is None:
            name = 'acc'
        super().__init__(name=name)

    def _compute_flattened(self, flattened_y_pred, flattened_y_true):
        if self.prediction_fn is not None:
            flattened_y_pred = self.prediction_fn(flattened_y_pred)
        return (flattened_y_pred == flattened_y_true).float()

    def worst(self, metrics):
        return minimum(metrics)


class MultiTaskAveragePrecision(MultiTaskMetric):
    def __init__(self, prediction_fn=None, name=None, average='macro'):
        self.prediction_fn = prediction_fn
        if name is None:
            name = 'avgprec'
            if average is not None:
                name += f'-{average}'
        self.average = average
        super().__init__(name=name)

    def _compute_flattened(self, flattened_y_pred, flattened_y_true):
        if self.prediction_fn is not None:
            flattened_y_pred = self.prediction_fn(flattened_y_pred)
        ytr = np.array(
            flattened_y_true.squeeze().detach().cpu().numpy() > 0)
        ypr = flattened_y_pred.squeeze().detach().cpu().numpy()
        score = sklearn.metrics.average_precision_score(
            ytr,
            ypr,
            average=self.average
        )
        to_ret = torch.tensor(score).to(flattened_y_pred.device)
        return to_ret

    def _compute_group_wise(self, y_pred, y_true, g, n_groups):
        group_metrics = []
        group_counts = get_counts(g, n_groups)
        for group_idx in range(n_groups):
            if group_counts[group_idx] == 0:
                group_metrics.append(torch.tensor(0., device=g.device))
            else:
                flattened_metrics, _ = self.compute_flattened(
                    y_pred[g == group_idx],
                    y_true[g == group_idx],
                    return_dict=False)
                group_metrics.append(flattened_metrics)
        group_metrics = torch.stack(group_metrics)
        worst_group_metric = self.worst(group_metrics[group_counts > 0])
        return group_metrics, group_counts, worst_group_metric

    def worst(self, metrics):
        return minimum(metrics)
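
# Sketch of the group-wise computation above (hypothetical tensors;
# assumes `g` holds one integer group index per example, and relies on
# the base class's `compute_flattened` as called in the method body):
#   >>> y_pred = torch.tensor([0.9, 0.2, 0.8, 0.1])
#   >>> y_true = torch.tensor([1, 0, 1, 0])
#   >>> g = torch.tensor([0, 0, 1, 1])
#   >>> metric = MultiTaskAveragePrecision()
#   >>> metric._compute_group_wise(y_pred, y_true, g, n_groups=2)
#   # -> per-group AP tensor, group counts, and the worst (minimum)
#   #    AP among non-empty groups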


class Recall(Metric):
    def __init__(self, prediction_fn=None, name=None, average='binary'):
        self.prediction_fn = prediction_fn
        if name is None:
            name = 'recall'
            if average is not None:
                name += f'-{average}'
        self.average = average
        super().__init__(name=name)

    def _compute(self, y_pred, y_true):
        if self.prediction_fn is not None:
            y_pred = self.prediction_fn(y_pred)
        recall = sklearn.metrics.recall_score(
            y_true, y_pred,
            average=self.average,
            labels=torch.unique(y_true))
        return torch.tensor(recall)

    def worst(self, metrics):
        return minimum(metrics)


class F1(Metric):
    def __init__(self, prediction_fn=None, name=None, average='binary'):
        self.prediction_fn = prediction_fn
        if name is None:
            name = 'F1'
            if average is not None:
                name += f'-{average}'
        self.average = average
        super().__init__(name=name)

    def _compute(self, y_pred, y_true):
        if self.prediction_fn is not None:
            y_pred = self.prediction_fn(y_pred)
        score = sklearn.metrics.f1_score(
            y_true, y_pred,
            average=self.average,
            labels=torch.unique(y_true))
        return torch.tensor(score)

    def worst(self, metrics):
        return minimum(metrics)


class PearsonCorrelation(Metric):
    def __init__(self, name=None):
        if name is None:
            name = 'r'
        super().__init__(name=name)

    def _compute(self, y_pred, y_true):
        r = pearsonr(
            y_pred.squeeze().detach().cpu().numpy(),
            y_true.squeeze().detach().cpu().numpy())[0]
        return torch.tensor(r)

    def worst(self, metrics):
        return minimum(metrics)


def mse_loss(out, targets):
    assert out.size() == targets.size()
    if out.numel() == 0:
        return torch.Tensor()
    else:
        assert out.dim() > 1, \
            'MSE loss currently supports Tensors of dimensions > 1'
        losses = (out - targets) ** 2
        # Average over all non-batch dimensions, one loss per example
        reduce_dims = tuple(list(range(1, len(targets.shape))))
        losses = torch.mean(losses, dim=reduce_dims)
        return losses
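
# Illustrative example (hypothetical values): per-example squared error,
# averaged over the non-batch dimension:
#   >>> out = torch.tensor([[1.0, 2.0], [0.0, 0.0]])
#   >>> targets = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
#   >>> mse_loss(out, targets)
#   tensor([2., 2.])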


class MSE(ElementwiseLoss):
    def __init__(self, name=None):
        if name is None:
            name = 'mse'
        super().__init__(name=name, loss_fn=mse_loss)


class PrecisionAtRecall(Metric):
    """Given a specific model threshold, determine the precision score
    achieved."""

    def __init__(self, threshold, score_fn=None, name=None):
        self.score_fn = score_fn
        self.threshold = threshold
        if name is None:
            name = "precision_at_global_recall"
        super().__init__(name=name)

    def _compute(self, y_pred, y_true):
        score = self.score_fn(y_pred)
        predictions = (score > self.threshold)
        return torch.tensor(
            sklearn.metrics.precision_score(y_true, predictions))

    def worst(self, metrics):
        return minimum(metrics)
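
# Sketch of typical usage (hypothetical values; in practice the threshold
# would be chosen so that the model attains a target global recall):
#   >>> metric = PrecisionAtRecall(threshold=0.5,
#   ...                            score_fn=binary_logits_to_score)
#   >>> logits = torch.tensor([2.0, -1.0, 0.8])  # single-logit outputs
#   >>> y_true = torch.tensor([1, 0, 0])
#   >>> metric._compute(logits, y_true)  # 1 TP, 1 FP
#   tensor(0.5000)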


class DummyMetric(Metric):
    """
    For testing purposes. This Metric always returns -1.
    """

    def __init__(self, prediction_fn=None, name=None):
        self.prediction_fn = prediction_fn
        if name is None:
            name = 'dummy'
        super().__init__(name=name)

    def _compute(self, y_pred, y_true):
        return torch.tensor(-1)

    def _compute_group_wise(self, y_pred, y_true, g, n_groups):
        group_metrics = torch.ones(n_groups, device=g.device) * -1
        group_counts = get_counts(g, n_groups)
        worst_group_metric = self.worst(group_metrics)
        return group_metrics, group_counts, worst_group_metric

    def worst(self, metrics):
        return minimum(metrics)


class DetectionAccuracy(ElementwiseMetric):
    """
    Given a specific intersection-over-union (IoU) threshold,
    determine the accuracy achieved for a one-class detector.
    """

    def __init__(self, iou_threshold=0.5, score_threshold=0.5, name=None):
        self.iou_threshold = iou_threshold
        self.score_threshold = score_threshold
        if name is None:
            name = "detection_acc"
        super().__init__(name=name)

    def _compute_element_wise(self, y_pred, y_true):
        batch_results = []
        for src_boxes, target in zip(y_true, y_pred):
            target_boxes = target["boxes"]
            target_scores = target["scores"]

            pred_boxes = target_boxes[target_scores > self.score_threshold]
            # np.arange(0.5, 0.51, 0.05) currently yields only [0.5]
            det_accuracy = torch.mean(torch.stack([
                self._accuracy(src_boxes["boxes"], pred_boxes, iou_thr)
                for iou_thr in np.arange(0.5, 0.51, 0.05)]))
            batch_results.append(det_accuracy)

        return torch.tensor(batch_results)

    def _accuracy(self, src_boxes, pred_boxes, iou_threshold):
        total_gt = len(src_boxes)
        total_pred = len(pred_boxes)
        if total_gt > 0 and total_pred > 0:
            # Define the matcher and distance matrix based on IoU
            matcher = Matcher(
                iou_threshold,
                iou_threshold,
                allow_low_quality_matches=False)
            match_quality_matrix = box_iou(
                src_boxes,
                pred_boxes)
            results = matcher(match_quality_matrix)
            true_positive = torch.count_nonzero(results.unique() != -1)
            matched_elements = results[results > -1]
            # in Matcher, a pred element can be matched only twice
            false_positive = (
                torch.count_nonzero(results == -1) +
                (len(matched_elements) - len(matched_elements.unique()))
            )
            false_negative = total_gt - true_positive
            return true_positive / (
                true_positive + false_positive + false_negative)
        elif total_gt == 0:
            if total_pred > 0:
                return torch.tensor(0.)
            else:
                return torch.tensor(1.)
        elif total_gt > 0 and total_pred == 0:
            return torch.tensor(0.)

    def worst(self, metrics):
        return minimum(metrics)
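

# Sketch of the inputs DetectionAccuracy expects (hypothetical boxes in
# (x1, y1, x2, y2) format; the dict layout mirrors what the code above
# reads from y_true and y_pred):
#   >>> y_true = [{"boxes": torch.tensor([[0., 0., 10., 10.]])}]
#   >>> y_pred = [{"boxes": torch.tensor([[0., 0., 10., 10.]]),
#   ...            "scores": torch.tensor([0.9])}]
#   >>> DetectionAccuracy()._compute_element_wise(y_pred, y_true)
#   tensor([1.])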