# algorithms/IRM.py

import torch
import torch.autograd as autograd
from algorithms.single_model_algorithm import SingleModelAlgorithm
from models.initializer import initialize_model
from optimizer import initialize_optimizer

from gds.common.metrics.metric import ElementwiseMetric, MultiTaskMetric
from gds.common.utils import split_into_groups


class IRM(SingleModelAlgorithm):
12 """
13 Invariant risk minimization.
14  
15 Original paper:
16 @article{arjovsky2019invariant,
17 title={Invariant risk minimization},
            author={Arjovsky, Martin and Bottou, L{\'e}on and
                    Gulrajani, Ishaan and Lopez-Paz, David},
            journal={arXiv preprint arXiv:1907.02893},
            year={2019}
        }

    The IRM penalty function below is adapted from the code snippet
    provided in the above paper.
    """

    def __init__(self, config, d_out, grouper, loss, metric, n_train_steps):
        """
        Algorithm-specific arguments (in config):
            - irm_lambda
            - irm_penalty_anneal_iters
        """
        # check config
        assert config.train_loader == 'group'
        assert config.uniform_over_groups
        assert config.distinct_groups
        # initialize model
        model = initialize_model(config, d_out).to(config.device)
        # initialize the module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )

        # additional logging
        self.logged_fields.append('penalty')
        # set IRM-specific variables
        self.irm_lambda = config.irm_lambda
        self.irm_penalty_anneal_iters = config.irm_penalty_anneal_iters
        self.scale = torch.tensor(1.).to(self.device).requires_grad_()
        self.update_count = 0
        # Store the config: IRM needs it to re-initialize the optimizer once
        # the penalty annealing period ends (see _update below).
        self.config = config

        assert isinstance(self.loss, (ElementwiseMetric, MultiTaskMetric))

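
    # The penalty below follows the code snippet in the paper: the IRM
    # penalty for one group is the squared norm of the gradient of that
    # group's risk with respect to the fixed dummy scale. Splitting the
    # per-example losses into two halves (even / odd indices) and taking
    # the product of the two half-batch gradients gives an unbiased
    # estimate of that squared norm.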
    def irm_penalty(self, losses):
        grad_1 = autograd.grad(
            losses[0::2].mean(), [self.scale], create_graph=True)[0]
        grad_2 = autograd.grad(
            losses[1::2].mean(), [self.scale], create_graph=True)[0]
        result = torch.sum(grad_1 * grad_2)
        return result

    def objective(self, results):
        # Compute the penalty on each group.
        # To be consistent with the DomainBed implementation, this returns
        # the average loss and penalty across groups, regardless of group
        # size. But the GroupLoader ensures that each group is of the same
        # size in each minibatch.
        unique_groups, group_indices, _ = split_into_groups(results['g'])
        n_groups_per_batch = unique_groups.numel()
        avg_loss = 0.
        penalty = 0.

        # Each element of group_indices is a list of indices into the batch
        for i_group in group_indices:
            group_losses, _ = self.loss.compute_flattened(
                self.scale * results['y_pred'][i_group],
                results['y_true'][i_group],
                return_dict=False)
            if group_losses.numel() > 0:
                avg_loss += group_losses.mean()
            if self.is_training:  # penalties only make sense when training
                penalty += self.irm_penalty(group_losses)
        avg_loss /= n_groups_per_batch
        penalty /= n_groups_per_batch

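        # Annealing schedule: for the first irm_penalty_anneal_iters updates
        # the penalty is weighted by 1.0; after that it is weighted by
        # irm_lambda, and _update() re-initializes the optimizer at that
        # boundary because the effective objective changes scale.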
        if self.update_count >= self.irm_penalty_anneal_iters:
            penalty_weight = self.irm_lambda
        else:
            penalty_weight = 1.0

        # Package the results
        if isinstance(penalty, torch.Tensor):
            results['penalty'] = penalty.item()
        else:
            results['penalty'] = penalty

        return avg_loss + penalty * penalty_weight

    def _update(self, results):
        if self.update_count == self.irm_penalty_anneal_iters:
            print('Hit IRM penalty anneal iters')
            # Reset optimizer to deal with the changing penalty weight
            self.optimizer = initialize_optimizer(self.config, self.model)
        super()._update(results)
        self.update_count += 1
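
# ---------------------------------------------------------------------------
# Minimal, self-contained sketch (not part of the original algorithm) showing
# the dummy-scale trick that irm_penalty relies on, using only torch and
# random data. The tensors, the logistic model, and the variable names below
# are illustrative assumptions, not part of this codebase.
if __name__ == '__main__':
    import torch.nn.functional as F

    torch.manual_seed(0)
    # Toy "environment": 16 examples, 5 features, binary labels.
    x = torch.randn(16, 5)
    y = torch.randint(0, 2, (16,)).float()
    w = torch.randn(5, requires_grad=True)      # stand-in for the model
    scale = torch.tensor(1.).requires_grad_()   # fixed dummy classifier w=1.0

    logits = x @ w
    losses = F.binary_cross_entropy_with_logits(
        scale * logits, y, reduction='none')

    # Gradients of the two half-batch risks w.r.t. the dummy scale; their
    # product estimates the squared gradient norm, i.e. the IRM penalty
    # for this single toy environment.
    g1 = autograd.grad(losses[0::2].mean(), [scale], create_graph=True)[0]
    g2 = autograd.grad(losses[1::2].mean(), [scale], create_graph=True)[0]
    print('toy IRM penalty:', torch.sum(g1 * g2).item())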