import numpy as np
import torch

from algorithms.algorithm import Algorithm
from scheduler import step_scheduler
from utils import update_average
from gds.common.utils import numel
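
# Note: this class assumes `update_average(prev_avg, prev_count, val, count)`
# implements a count-weighted running mean,
#   (prev_avg * prev_count + val * count) / (prev_count + count),
# and that `numel(x)` returns the number of elements in x. See utils.py and
# gds/common/utils.py for the actual implementations.
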
class GroupAlgorithm(Algorithm):
"""
Parent class for algorithms with group-wise logging.
Also handles schedulers.
"""

    def __init__(self, device, grouper, logged_metrics, logged_fields, schedulers, scheduler_metric_names,
                 no_group_logging, **kwargs):
        """
        Args:
            - device: torch device
            - grouper (Grouper): defines the groups for which we compute/log statistics
            - logged_metrics (list of Metric): metrics to compute and log
            - logged_fields (list of str): fields in the results dictionary to log
            - schedulers (list): schedulers to step, paired elementwise with scheduler_metric_names
            - scheduler_metric_names (list of str): name of the metric each scheduler steps on
              (entries may be None for schedulers that do not use a metric)
            - no_group_logging (bool): if True, skip all group-wise logging
        """
super().__init__(device)
self.grouper = grouper
self.group_prefix = 'group_'
self.count_field = 'count'
self.group_count_field = f'{self.group_prefix}{self.count_field}'
self.logged_metrics = logged_metrics
self.logged_fields = logged_fields
self.schedulers = schedulers
self.scheduler_metric_names = scheduler_metric_names
self.no_group_logging = no_group_logging
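        # note: self.log_dict and self._has_log (used by update_log below) are
        # assumed to be initialized by the Algorithm base class, e.g. when the
        # log is reset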

    def update_log(self, results):
"""
Updates the internal log, Algorithm.log_dict
Args:
- results (dictionary)
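              Must contain 'y_true', 'y_pred', every field in self.logged_fields,
              and (unless no_group_logging is set) 'g' with per-example group indices.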
"""
results = self.sanitize_dict(results, to_out_device=False)
# check all the fields exist
for field in self.logged_fields:
assert field in results, f"field {field} missing"
# compute statistics for the current batch
batch_log = {}
with torch.no_grad():
for m in self.logged_metrics:
if not self.no_group_logging:
group_metrics, group_counts, worst_group_metric = m.compute_group_wise(
results['y_pred'],
results['y_true'],
results['g'],
self.grouper.n_groups,
return_dict=False)
batch_log[f'{self.group_prefix}{m.name}'] = group_metrics
batch_log[m.agg_metric_field] = m.compute(
results['y_pred'],
results['y_true'],
return_dict=False).item()
count = numel(results['y_true'])
# transfer other statistics in the results dictionary
for field in self.logged_fields:
if field.startswith(self.group_prefix) and self.no_group_logging:
continue
v = results[field]
if isinstance(v, torch.Tensor) and v.numel() == 1:
batch_log[field] = v.item()
else:
if isinstance(v, torch.Tensor):
assert v.numel() == self.grouper.n_groups, "Current implementation deals only with group-wise statistics or a single-number statistic"
assert field.startswith(self.group_prefix)
batch_log[field] = v
# update the log dict with the current batch
if not self._has_log: # since it is the first log entry, just save the current log
self.log_dict = batch_log
if not self.no_group_logging:
self.log_dict[self.group_count_field] = group_counts
self.log_dict[self.count_field] = count
else: # take a running average across batches otherwise
for k, v in batch_log.items():
if k.startswith(self.group_prefix):
if self.no_group_logging:
continue
self.log_dict[k] = update_average(self.log_dict[k], self.log_dict[self.group_count_field], v,
group_counts)
else:
self.log_dict[k] = update_average(self.log_dict[k], self.log_dict[self.count_field], v, count)
if not self.no_group_logging:
self.log_dict[self.group_count_field] += group_counts
self.log_dict[self.count_field] += count
self._has_log = True

    def get_log(self):
"""
Sanitizes the internal log (Algorithm.log_dict) and outputs it.
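        Group-wise tensors in the log are flattened into per-group scalar
        entries (keyed '{field}_group:{g}' or via Metric.group_metric_field),
        with NaN reported for groups whose count is zero.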
"""
sanitized_log = {}
for k, v in self.log_dict.items():
if k.startswith(self.group_prefix):
field = k[len(self.group_prefix):]
for g in range(self.grouper.n_groups):
# set relevant values to NaN depending on the group count
count = self.log_dict[self.group_count_field][g].item()
if count == 0 and k != self.group_count_field:
outval = np.nan
else:
outval = v[g].item()
# add to dictionary with an appropriate name
# in practice, it is saving each value as {field}_group:{g}
added = False
for m in self.logged_metrics:
if field == m.name:
sanitized_log[m.group_metric_field(g)] = outval
added = True
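                    # self.loss is assumed to be set by subclasses
                    # (e.g., SingleModelAlgorithm) before get_log is called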
if k == self.group_count_field:
sanitized_log[self.loss.group_count_field(g)] = outval
added = True
elif not added:
sanitized_log[f'{field}_group:{g}'] = outval
else:
assert not isinstance(v, torch.Tensor)
sanitized_log[k] = v
return sanitized_log

    def step_schedulers(self, is_epoch, metrics=None, log_access=False):
        """
        Updates the schedulers after a batch or an epoch.
        If a scheduler is updated based on a metric (SingleModelAlgorithm.scheduler_metric),
        it first looks for an entry in metrics and then, if log_access is True,
        in the internal log (SingleModelAlgorithm.log_dict).
        Args:
            - is_epoch (bool): True for epoch-wise scheduler updates, False for batch-wise updates
            - metrics (dict): metrics that schedulers can step on
            - log_access (bool): whether the scheduler_metric can be fetched from the internal log
              (self.log_dict)
        """
        if metrics is None:  # avoid a mutable default argument
            metrics = {}
for scheduler, metric_name in zip(self.schedulers, self.scheduler_metric_names):
if scheduler is None:
continue
if is_epoch and scheduler.step_every_batch:
continue
if (not is_epoch) and (not scheduler.step_every_batch):
continue
self._step_specific_scheduler(
scheduler=scheduler,
metric_name=metric_name,
metrics=metrics,
log_access=log_access)
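    # Example call pattern (illustrative; `algorithm` and `val_metrics` are
    # hypothetical names from the surrounding training code):
    #     algorithm.step_schedulers(is_epoch=False)   # after each batch
    #     algorithm.step_schedulers(is_epoch=True, metrics=val_metrics,
    #                               log_access=True)  # after each epoch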

    def _step_specific_scheduler(self, scheduler, metric_name, metrics, log_access):
        """
        Helper function for updating a single scheduler.
        Args:
            - scheduler: scheduler to update
            - metric_name (str): name of the metric (key in metrics or the log dictionary) to use for updates
            - metrics (dict): a dictionary of metrics that can be used for scheduler updates
            - log_access (bool): whether metrics from self.get_log() can be used to update schedulers
        """
if not scheduler.use_metric or metric_name is None:
metric = None
elif metric_name in metrics:
metric = metrics[metric_name]
elif log_access:
sanitized_log_dict = self.get_log()
if metric_name in sanitized_log_dict:
metric = sanitized_log_dict[metric_name]
            else:
                raise ValueError(f'scheduler metric {metric_name} not found in the internal log')
        else:
            raise ValueError(f'scheduler metric {metric_name} not found in metrics')
step_scheduler(scheduler, metric)

    def get_pretty_log_str(self):
"""
Output:
- results_str (str)
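              Example output (illustrative values; the exact field and metric
              names depend on logged_fields and logged_metrics):
                  loss: 0.512
                  acc_avg: 0.842
                    group 0 [n =   1024]:  loss: 0.498  acc: 0.851
                    group 1 [n =    512]:  loss: 0.540  acc: 0.824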
"""
results_str = ''
# Get sanitized log dict
log = self.get_log()
# Process aggregate logged fields
for field in self.logged_fields:
if field.startswith(self.group_prefix):
continue
results_str += (
f'{field}: {log[field]:.3f}\n'
)
# Process aggregate logged metrics
for metric in self.logged_metrics:
results_str += (
f'{metric.agg_metric_field}: {log[metric.agg_metric_field]:.3f}\n'
)
# Process logs for each group
if not self.no_group_logging:
for g in range(self.grouper.n_groups):
group_count = log[f"count_group:{g}"]
if group_count <= 0:
continue
results_str += (
f' {self.grouper.group_str(g)} '
f'[n = {group_count:6.0f}]:\t'
)
# Process grouped logged fields
for field in self.logged_fields:
if field.startswith(self.group_prefix):
field_suffix = field[len(self.group_prefix):]
log_key = f'{field_suffix}_group:{g}'
results_str += (
f'{field_suffix}: '
f'{log[log_key]:5.3f}\t'
)
# Process grouped metric fields
for metric in self.logged_metrics:
results_str += (
f'{metric.name}: '
f'{log[metric.group_metric_field(g)]:5.3f}\t'
)
results_str += '\n'
else:
results_str += '\n'
return results_str
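
# Typical usage from a training loop (a sketch; `algorithm` is a concrete
# subclass such as SingleModelAlgorithm, and `update` is assumed to be the
# subclass method that processes a batch and returns the results dict):
#
#     for batch in train_loader:
#         results = algorithm.update(batch)
#         algorithm.update_log(results)
#         algorithm.step_schedulers(is_epoch=False)
#     print(algorithm.get_pretty_log_str())
#     algorithm.step_schedulers(is_epoch=True, metrics=algorithm.get_log(),
#                               log_access=True)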