experiments.algorithms.group_algorithm

experiments/algorithms/group_algorithm.py
import numpy as np
import torch
from algorithms.algorithm import Algorithm
from scheduler import step_scheduler
from utils import update_average

from gds.common.utils import numel


class GroupAlgorithm(Algorithm):
    """
    Parent class for algorithms with group-wise logging.
    Also handles schedulers.
    """

    def __init__(self, device, grouper, logged_metrics, logged_fields, schedulers, scheduler_metric_names,
                 no_group_logging, **kwargs):
        """
        Args:
            - device: torch device
            - grouper (Grouper): defines the groups for which we compute/log stats
            - logged_metrics (list of Metric): metrics to compute and log, both in aggregate and per group
            - logged_fields (list of str): fields in the results dictionary to log as-is
            - schedulers (list): schedulers to step, batch-wise or epoch-wise
            - scheduler_metric_names (list of str): for each scheduler, the name of the
                                                    metric (if any) that drives its updates
            - no_group_logging (bool): if True, skip group-wise logging entirely
        """
        super().__init__(device)
        self.grouper = grouper
        self.group_prefix = 'group_'
        self.count_field = 'count'
        self.group_count_field = f'{self.group_prefix}{self.count_field}'

        self.logged_metrics = logged_metrics
        self.logged_fields = logged_fields

        self.schedulers = schedulers
        self.scheduler_metric_names = scheduler_metric_names
        self.no_group_logging = no_group_logging

    def update_log(self, results):
        """
        Updates the internal log, Algorithm.log_dict
        Args:
            - results (dictionary)
        """
        results = self.sanitize_dict(results, to_out_device=False)
        # check all the fields exist
        for field in self.logged_fields:
            assert field in results, f"field {field} missing"
        # compute statistics for the current batch
        batch_log = {}
        with torch.no_grad():
            for m in self.logged_metrics:
                if not self.no_group_logging:
                    group_metrics, group_counts, worst_group_metric = m.compute_group_wise(
                        results['y_pred'],
                        results['y_true'],
                        results['g'],
                        self.grouper.n_groups,
                        return_dict=False)
                    batch_log[f'{self.group_prefix}{m.name}'] = group_metrics
                batch_log[m.agg_metric_field] = m.compute(
                    results['y_pred'],
                    results['y_true'],
                    return_dict=False).item()
            count = numel(results['y_true'])

        # transfer other statistics in the results dictionary
        for field in self.logged_fields:
            if field.startswith(self.group_prefix) and self.no_group_logging:
                continue
            v = results[field]
            if isinstance(v, torch.Tensor) and v.numel() == 1:
                batch_log[field] = v.item()
            else:
                if isinstance(v, torch.Tensor):
                    assert v.numel() == self.grouper.n_groups, "Current implementation deals only with group-wise statistics or a single-number statistic"
                    assert field.startswith(self.group_prefix)
                batch_log[field] = v

        # update the log dict with the current batch
        if not self._has_log:  # since it is the first log entry, just save the current log
            self.log_dict = batch_log
            if not self.no_group_logging:
                self.log_dict[self.group_count_field] = group_counts
            self.log_dict[self.count_field] = count
        else:  # take a running average across batches otherwise
            for k, v in batch_log.items():
                if k.startswith(self.group_prefix):
                    if self.no_group_logging:
                        continue
                    self.log_dict[k] = update_average(self.log_dict[k], self.log_dict[self.group_count_field], v,
                                                      group_counts)
                else:
                    self.log_dict[k] = update_average(self.log_dict[k], self.log_dict[self.count_field], v, count)
            if not self.no_group_logging:
                self.log_dict[self.group_count_field] += group_counts
            self.log_dict[self.count_field] += count
        self._has_log = True
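
    # A minimal sketch (not part of the original module) of what the imported
    # update_average helper is assumed to compute: a count-weighted mean of the
    # previous running average and the current batch average. The actual helper
    # in utils may differ, e.g. in how it guards against zero counts.
    @staticmethod
    def _update_average_sketch(prev_avg, prev_count, curr_avg, curr_count):
        denom = prev_count + curr_count
        if isinstance(denom, torch.Tensor):
            denom = denom + (denom == 0).float()  # avoid 0/0 for empty groups
        elif denom == 0:
            return 0.
        return (prev_count * prev_avg + curr_count * curr_avg) / denom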

    def get_log(self):
        """
        Sanitizes the internal log (Algorithm.log_dict) and outputs it.
        """
        sanitized_log = {}
        for k, v in self.log_dict.items():
            if k.startswith(self.group_prefix):
                field = k[len(self.group_prefix):]
                for g in range(self.grouper.n_groups):
                    # set relevant values to NaN depending on the group count
                    count = self.log_dict[self.group_count_field][g].item()
                    if count == 0 and k != self.group_count_field:
                        outval = np.nan
                    else:
                        outval = v[g].item()
                    # add to the dictionary under an appropriate name;
                    # in practice, each value is saved as {field}_group:{g}
                    added = False
                    for m in self.logged_metrics:
                        if field == m.name:
                            sanitized_log[m.group_metric_field(g)] = outval
                            added = True
                    if k == self.group_count_field:
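                        # self.loss is set by subclasses (e.g. SingleModelAlgorithm);
                        # its group_count_field(g) is assumed to format the per-group
                        # count key as f'count_group:{g}', matching get_pretty_log_str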
                        sanitized_log[self.loss.group_count_field(g)] = outval
                        added = True
                    elif not added:
                        sanitized_log[f'{field}_group:{g}'] = outval
            else:
                assert not isinstance(v, torch.Tensor)
                sanitized_log[k] = v
        return sanitized_log
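
    # Illustration (assumed values, not from the original source): with
    # n_groups == 2, a logged metric named 'acc', and
    #   log_dict == {'group_acc': tensor([0.8, 0.5]),
    #                'group_count': tensor([10., 0.]),
    #                'count': 10, 'loss': 0.7},
    # get_log() would return per-group scalars, masking metric values for
    # empty groups with NaN (counts themselves are never masked):
    #   {'acc_group:0': 0.8, 'acc_group:1': nan,
    #    'count_group:0': 10.0, 'count_group:1': 0.0,
    #    'count': 10, 'loss': 0.7}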

    def step_schedulers(self, is_epoch, metrics=None, log_access=False):
        """
        Steps each scheduler, either batch-wise or epoch-wise depending on its
        step_every_batch flag.
        If a scheduler is updated based on a metric (SingleModelAlgorithm.scheduler_metric),
        then it first looks for an entry in metrics and, if log_access is True,
        falls back to the internal log (SingleModelAlgorithm.log_dict).
        Args:
            - is_epoch (bool): epoch-wise update if True, batch-wise update otherwise
            - metrics (dict): metrics that can be used for scheduler updates
            - log_access (bool): whether the scheduler_metric can be fetched from the
                                 internal log (self.log_dict)
        """
        if metrics is None:
            metrics = {}
        for scheduler, metric_name in zip(self.schedulers, self.scheduler_metric_names):
            if scheduler is None:
                continue
            if is_epoch and scheduler.step_every_batch:
                continue  # batch-wise scheduler; don't step it at epoch boundaries
            if (not is_epoch) and (not scheduler.step_every_batch):
                continue  # epoch-wise scheduler; don't step it within an epoch
            self._step_specific_scheduler(
                scheduler=scheduler,
                metric_name=metric_name,
                metrics=metrics,
                log_access=log_access)

    def _step_specific_scheduler(self, scheduler, metric_name, metrics, log_access):
        """
        Helper function for updating scheduler
        Args:
            - scheduler: scheduler to update
            - is_epoch (bool): epoch-wise update if set to True, batch-wise update otherwise
            - metric_name (str): name of the metric (key in metrics or log dictionary) to use for updates
            - metrics (dict): a dictionary of metrics that can beused for scheduler updates
            - log_access (bool): whether metrics from self.get_log() can be used to update schedulers
        """
        if not scheduler.use_metric or metric_name is None:
            metric = None
        elif metric_name in metrics:
            metric = metrics[metric_name]
        elif log_access:
            sanitized_log_dict = self.get_log()
            if metric_name in sanitized_log_dict:
                metric = sanitized_log_dict[metric_name]
            else:
                raise ValueError(f'Scheduler metric {metric_name} not found in metrics or the internal log')
        else:
            raise ValueError(f'Scheduler metric {metric_name} not found in metrics (log_access is False)')
        step_scheduler(scheduler, metric)
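
    # The imported step_scheduler helper is assumed to dispatch roughly as
    # follows (a sketch; the actual implementation lives in scheduler.py):
    #     if metric is None: scheduler.step()
    #     else:              scheduler.step(metric)  # e.g. ReduceLROnPlateau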

    def get_pretty_log_str(self):
        """
        Output:
            - results_str (str)
        """
        results_str = ''

        # Get sanitized log dict
        log = self.get_log()

        # Process aggregate logged fields
        for field in self.logged_fields:
            if field.startswith(self.group_prefix):
                continue
            results_str += (
                f'{field}: {log[field]:.3f}\n'
            )

        # Process aggregate logged metrics
        for metric in self.logged_metrics:
            results_str += (
                f'{metric.agg_metric_field}: {log[metric.agg_metric_field]:.3f}\n'
            )

        # Process logs for each group
        if not self.no_group_logging:
            for g in range(self.grouper.n_groups):
                group_count = log[f"count_group:{g}"]
                if group_count <= 0:
                    continue

                results_str += (
                    f'  {self.grouper.group_str(g)}  '
                    f'[n = {group_count:6.0f}]:\t'
                )

                # Process grouped logged fields
                for field in self.logged_fields:
                    if field.startswith(self.group_prefix):
                        field_suffix = field[len(self.group_prefix):]
                        log_key = f'{field_suffix}_group:{g}'
                        results_str += (
                            f'{field_suffix}: '
                            f'{log[log_key]:5.3f}\t'
                        )

                # Process grouped metric fields
                for metric in self.logged_metrics:
                    results_str += (
                        f'{metric.name}: '
                        f'{log[metric.group_metric_field(g)]:5.3f}\t'
                    )
                results_str += '\n'
        else:
            results_str += '\n'

        return results_str
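
# --------------------------------------------------------------------------
# Typical call pattern (a hedged sketch; the real training loop lives
# elsewhere in the codebase and may differ). `update` and `reset_log` are
# assumed to be provided by Algorithm or its subclasses:
#
#   for epoch in range(n_epochs):
#       algorithm.reset_log()
#       for batch in train_loader:
#           results = algorithm.update(batch)          # forward/backward step
#           algorithm.update_log(results)
#           algorithm.step_schedulers(is_epoch=False)  # batch-wise schedulers
#       print(algorithm.get_pretty_log_str())
#       algorithm.step_schedulers(
#           is_epoch=True,                             # epoch-wise schedulers
#           metrics={'loss': val_loss},                # hypothetical metric
#           log_access=True)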