1 import torch
2 from algorithms.single_model_algorithm import SingleModelAlgorithm
3 from models.initializer import initialize_model
4
5
class GroupDRO(SingleModelAlgorithm):
    """
    Group distributionally robust optimization.

    The model is trained on a weighted average of per-group losses. Before
    each model update, the group weights are shifted towards the groups with
    the highest current loss by a multiplicative (exponentiated-gradient)
    step, w_g <- w_g * exp(step_size * loss_g), followed by renormalization.

    Original paper:
        @inproceedings{sagawa2019distributionally,
          title={Distributionally robust neural networks for group shifts:
                 On the importance of regularization for worst-case
                 generalization},
          author={Sagawa, Shiori and Koh, Pang Wei and
                  Hashimoto, Tatsunori B and Liang, Percy},
          booktitle={International Conference on Learning Representations},
          year={2019}
        }
    """

    def __init__(self, config, d_out, grouper, loss, metric, n_train_steps,
                 is_group_in_train):
        """
        Args:
            - config: experiment configuration. Reads `uniform_over_groups`
              (must be truthy), `group_dro_step_size` (step size of the
              group-weight update) and `device`.
            - d_out: forwarded to initialize_model() (presumably the model's
              output dimension).
            - grouper: provides `n_groups`, the number of group weights.
            - loss: loss object whose `compute_group_wise()` yields the
              per-group losses.
            - metric: forwarded to SingleModelAlgorithm.
            - n_train_steps: forwarded to SingleModelAlgorithm.
            - is_group_in_train: index over the group axis selecting the
              groups that start with non-zero weight (presumably a boolean
              mask of length n_groups -- confirm against the caller).
        Raises:
            - ValueError: if `config.uniform_over_groups` is falsy, or if
              `is_group_in_train` selects no group (the initial weights
              would otherwise be 0/0 = NaN).
        """
        # Explicit check rather than `assert`, which `python -O` strips.
        if not config.uniform_over_groups:
            raise ValueError(
                'GroupDRO requires config.uniform_over_groups to be set.')
        # initialize model
        model = initialize_model(config, d_out).to(config.device)
        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        # additional logging: the current group weights
        self.logged_fields.append('group_weight')
        # step size of the multiplicative group-weight update
        self.group_weights_step_size = config.group_dro_step_size
        # Adversarial weights start uniform over the training groups and at
        # zero elsewhere; a zero weight stays zero under the update.
        group_weights = torch.zeros(grouper.n_groups)
        group_weights[is_group_in_train] = 1
        n_active = group_weights.sum()
        if n_active == 0:
            raise ValueError(
                'GroupDRO needs at least one training group, but '
                'is_group_in_train selects none.')
        self.group_weights = (group_weights / n_active).to(self.device)

    def process_batch(self, batch):
        """
        A helper function for update() and evaluate() that processes the
        batch and attaches the current group weights.
        Args:
            - batch (tuple of Tensors): a batch of data yielded by data loaders
        Output:
            - results (dictionary): the dictionary returned by
              SingleModelAlgorithm.process_batch(), plus
                - group_weight (Tensor): the current group weights, one
                  entry per group (not per example)
        """
        results = super().process_batch(batch)
        results['group_weight'] = self.group_weights
        return results

    def objective(self, results):
        """
        Takes an output of process_batch() and computes the optimized
        objective. For group DRO, the objective is the average of the
        per-group losses weighted by self.group_weights (which sum to one).
        Args:
            - results (dictionary): output of process_batch(); must contain
              'y_pred', 'y_true' and 'g'.
        Output:
            - objective (Tensor): the weighted loss; a scalar (0-dim)
              tensor when the per-group losses are a vector of length
              n_groups.
        """
        group_losses = self._compute_group_losses(results)
        return group_losses @ self.group_weights

    def _update(self, results):
        """
        Update the group weights from a processed batch, record them in
        `results`, then update the model via SingleModelAlgorithm._update().
        Because the weights change first, any objective() evaluated during
        the model update already uses the new weights.
        Args:
            - results (dictionary): output of process_batch(); must contain
              'y_pred', 'y_true' and 'g'. Its 'group_weight' entry is
              overwritten with the updated weights.
        Output:
            - None; `results` is modified in place.
        """
        # compute group losses and take one multiplicative-weights step
        group_losses = self._compute_group_losses(results)
        self.group_weights = self._exponentiated_gradient_step(
            self.group_weights, group_losses, self.group_weights_step_size)
        # Save updated group weights. The step returns a new tensor, so
        # weights stored for earlier batches are not overwritten.
        results['group_weight'] = self.group_weights
        # update model
        super()._update(results)

    def _compute_group_losses(self, results):
        """
        Per-group losses of a processed batch, as computed by
        `self.loss.compute_group_wise()` (only its first return value is
        kept). Shared by objective() and _update().
        """
        group_losses, _, _ = self.loss.compute_group_wise(
            results['y_pred'],
            results['y_true'],
            results['g'],
            self.grouper.n_groups,
            return_dict=False)
        return group_losses

    @staticmethod
    def _exponentiated_gradient_step(group_weights, group_losses, step_size):
        """
        Return group_weights * exp(step_size * group_losses), normalized to
        sum to one, without tracking gradients through the losses.

        Evaluated as softmax(log(group_weights) + step_size * group_losses)
        over a 1-D vector of groups. This is mathematically identical, but
        softmax is computed stably (shifted by its maximum), so a large
        step_size * loss no longer overflows `exp` to inf and turns the
        weights into NaN. Zero weights map to log(0) = -inf and therefore
        stay exactly zero.
        """
        logits = torch.log(group_weights) + step_size * group_losses.detach()
        return torch.softmax(logits, dim=0)