1 import numpy as np
2 import torch
3 from algorithms.algorithm import Algorithm
4 from scheduler import step_scheduler
5 from utils import update_average
6
7 from gds.common.utils import numel
8
9
class GroupAlgorithm(Algorithm):
    """
    Parent class for algorithms with group-wise logging.
    Also handles schedulers.
    """

    def __init__(self, device, grouper, logged_metrics, logged_fields,
                 schedulers, scheduler_metric_names, no_group_logging,
                 **kwargs):
        """
        Args:
            - device: torch device
            - grouper (Grouper): defines groups for which we compute/log
              stats for
            - logged_metrics (list of Metric): metrics computed and logged
              per batch (and per group unless no_group_logging)
            - logged_fields (list of str): keys in `results` to carry over
              into the internal log
            - schedulers (list): schedulers to step; entries may be None
            - scheduler_metric_names (list of str): name of the metric each
              scheduler is stepped on (parallel to `schedulers`)
            - no_group_logging (bool): if True, skip all group-wise stats
        """
        super().__init__(device)
        self.grouper = grouper
        # Keys used to store (group-wise) sample counts in the internal log.
        self.group_prefix = 'group_'
        self.count_field = 'count'
        self.group_count_field = f'{self.group_prefix}{self.count_field}'

        self.logged_metrics = logged_metrics
        self.logged_fields = logged_fields

        self.schedulers = schedulers
        self.scheduler_metric_names = scheduler_metric_names
        self.no_group_logging = no_group_logging

    def update_log(self, results):
        """
        Updates the internal log, Algorithm.log_dict
        Args:
            - results (dictionary): must contain 'y_pred', 'y_true', 'g',
              and every key listed in self.logged_fields
        """
        results = self.sanitize_dict(results, to_out_device=False)
        # check all the fields exist
        for field in self.logged_fields:
            assert field in results, f"field {field} missing"
        # compute statistics for the current batch
        batch_log = {}
        with torch.no_grad():
            for m in self.logged_metrics:
                if not self.no_group_logging:
                    group_metrics, group_counts, worst_group_metric = \
                        m.compute_group_wise(
                            results['y_pred'],
                            results['y_true'],
                            results['g'],
                            self.grouper.n_groups,
                            return_dict=False)
                    batch_log[f'{self.group_prefix}{m.name}'] = group_metrics
                batch_log[m.agg_metric_field] = m.compute(
                    results['y_pred'],
                    results['y_true'],
                    return_dict=False).item()
            count = numel(results['y_true'])

        # transfer other statistics in the results dictionary
        for field in self.logged_fields:
            if field.startswith(self.group_prefix) and self.no_group_logging:
                continue
            v = results[field]
            if isinstance(v, torch.Tensor) and v.numel() == 1:
                batch_log[field] = v.item()
            else:
                if isinstance(v, torch.Tensor):
                    # only scalars or one-value-per-group tensors are allowed
                    assert v.numel() == self.grouper.n_groups, \
                        "Current implementation deals only with group-wise statistics or a single-number statistic"
                    assert field.startswith(self.group_prefix)
                batch_log[field] = v

        # update the log dict with the current batch
        if not self._has_log:
            # since it is the first log entry, just save the current log
            self.log_dict = batch_log
            if not self.no_group_logging:
                self.log_dict[self.group_count_field] = group_counts
            self.log_dict[self.count_field] = count
        else:
            # take a running average across batches otherwise
            for k, v in batch_log.items():
                if k.startswith(self.group_prefix):
                    if self.no_group_logging:
                        continue
                    self.log_dict[k] = update_average(
                        self.log_dict[k],
                        self.log_dict[self.group_count_field],
                        v,
                        group_counts)
                else:
                    self.log_dict[k] = update_average(
                        self.log_dict[k],
                        self.log_dict[self.count_field],
                        v,
                        count)
            if not self.no_group_logging:
                self.log_dict[self.group_count_field] += group_counts
            self.log_dict[self.count_field] += count
        self._has_log = True

    def get_log(self):
        """
        Sanitizes the internal log (Algorithm.log_dict) and outputs it.
        """
        sanitized_log = {}
        for k, v in self.log_dict.items():
            if k.startswith(self.group_prefix):
                field = k[len(self.group_prefix):]
                for g in range(self.grouper.n_groups):
                    # set relevant values to NaN depending on the group count
                    count = self.log_dict[self.group_count_field][g].item()
                    if count == 0 and k != self.group_count_field:
                        outval = np.nan
                    else:
                        outval = v[g].item()
                    # add to dictionary with an appropriate name
                    # in practice, it is saving each value as {field}_group:{g}
                    added = False
                    for m in self.logged_metrics:
                        if field == m.name:
                            sanitized_log[m.group_metric_field(g)] = outval
                            added = True
                    if k == self.group_count_field:
                        # NOTE(review): self.loss is not defined in this
                        # class; it is presumably set by a subclass — confirm
                        sanitized_log[self.loss.group_count_field(g)] = outval
                        added = True
                    elif not added:
                        sanitized_log[f'{field}_group:{g}'] = outval
            else:
                assert not isinstance(v, torch.Tensor)
                sanitized_log[k] = v
        return sanitized_log

    def step_schedulers(self, is_epoch, metrics=None, log_access=False):
        """
        Updates the scheduler after an epoch.
        If a scheduler is updated based on a metric
        (SingleModelAlgorithm.scheduler_metric), then it first looks for an
        entry in `metrics` and then in its internal log
        (SingleModelAlgorithm.log_dict) if log_access is True.
        Args:
            - is_epoch (bool): epoch-wise step if True, batch-wise otherwise
            - metrics (dict): metric values usable for scheduler updates
            - log_access (bool): whether the scheduler_metric can be fetched
              from internal log (self.log_dict)
        """
        # Use None as the default to avoid sharing one mutable dict across
        # calls (mutable-default-argument pitfall).
        if metrics is None:
            metrics = {}
        for scheduler, metric_name in zip(self.schedulers,
                                          self.scheduler_metric_names):
            if scheduler is None:
                continue
            # skip schedulers that should not step at this granularity
            if is_epoch and scheduler.step_every_batch:
                continue
            if (not is_epoch) and (not scheduler.step_every_batch):
                continue
            self._step_specific_scheduler(
                scheduler=scheduler,
                metric_name=metric_name,
                metrics=metrics,
                log_access=log_access)

    def _step_specific_scheduler(self, scheduler, metric_name, metrics,
                                 log_access):
        """
        Helper function for updating scheduler
        Args:
            - scheduler: scheduler to update
            - metric_name (str): name of the metric (key in metrics or log
              dictionary) to use for updates
            - metrics (dict): a dictionary of metrics that can be used for
              scheduler updates
            - log_access (bool): whether metrics from self.get_log() can be
              used to update schedulers
        Raises:
            - ValueError: if the scheduler needs a metric but metric_name is
              found neither in `metrics` nor in the (accessible) internal log
        """
        if not scheduler.use_metric or metric_name is None:
            metric = None
        elif metric_name in metrics:
            metric = metrics[metric_name]
        elif log_access:
            sanitized_log_dict = self.get_log()
            if metric_name in sanitized_log_dict:
                metric = sanitized_log_dict[metric_name]
            else:
                raise ValueError('scheduler metric not recognized')
        else:
            raise ValueError('scheduler metric not recognized')
        step_scheduler(scheduler, metric)

    def get_pretty_log_str(self):
        """
        Pretty-prints the sanitized log.
        Output:
            - results_str (str)
        """
        results_str = ''

        # Get sanitized log dict
        log = self.get_log()

        # Process aggregate logged fields
        for field in self.logged_fields:
            if field.startswith(self.group_prefix):
                continue
            results_str += (
                f'{field}: {log[field]:.3f}\n'
            )

        # Process aggregate logged metrics
        for metric in self.logged_metrics:
            results_str += (
                f'{metric.agg_metric_field}: '
                f'{log[metric.agg_metric_field]:.3f}\n'
            )

        # Process logs for each group
        if not self.no_group_logging:
            for g in range(self.grouper.n_groups):
                group_count = log[f"count_group:{g}"]
                # skip empty groups (their metric entries are NaN)
                if group_count <= 0:
                    continue

                results_str += (
                    f'  {self.grouper.group_str(g)}  '
                    f'[n = {group_count:6.0f}]:\t'
                )

                # Process grouped logged fields
                for field in self.logged_fields:
                    if field.startswith(self.group_prefix):
                        field_suffix = field[len(self.group_prefix):]
                        log_key = f'{field_suffix}_group:{g}'
                        results_str += (
                            f'{field_suffix}: '
                            f'{log[log_key]:5.3f}\t'
                        )

                # Process grouped metric fields
                for metric in self.logged_metrics:
                    results_str += (
                        f'{metric.name}: '
                        f'{log[metric.group_metric_field(g)]:5.3f}\t'
                    )
                results_str += '\n'
        else:
            results_str += '\n'

        return results_str