import torch

from algorithms.single_model_algorithm import SingleModelAlgorithm
from models.initializer import initialize_model


            6 class GroupDRO(SingleModelAlgorithm):
            
            7     """
            
            8     Group distributionally robust optimization.
            
            9  
            
            10     Original paper:
            
            11         @inproceedings{sagawa2019distributionally,
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (142 > 79 characters)
 
               
               
12           title={Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization},
             
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (96 > 79 characters)
 
               
               
13           author={Sagawa, Shiori and Koh, Pang Wei and Hashimoto, Tatsunori B and Liang, Percy},
             
            14           booktitle={International Conference on Learning Representations},
            
            15           year={2019}
            
            
            17     """
            
            18  
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (95 > 79 characters)
 
               
               
19     def __init__(self, config, d_out, grouper, loss, metric, n_train_steps, is_group_in_train):
             
            20         # check config
            
            21         assert config.uniform_over_groups
            
            22         # initialize model
            
            23         model = initialize_model(config, d_out).to(config.device)
            
            24         # initialize module
            
            25         super().__init__(
            
            26             config=config,
            
            27             model=model,
            
            28             grouper=grouper,
            
            29             loss=loss,
            
            30             metric=metric,
            
            31             n_train_steps=n_train_steps,
            
            32         )
            
            33         # additional logging
            
            34         self.logged_fields.append('group_weight')
            
            35         # step size
            
            36         self.group_weights_step_size = config.group_dro_step_size
            
            37         # initialize adversarial weights
            
            38         self.group_weights = torch.zeros(grouper.n_groups)
            
            39         self.group_weights[is_group_in_train] = 1
            
            40         self.group_weights = self.group_weights / self.group_weights.sum()
            
            41         self.group_weights = self.group_weights.to(self.device)
            
            42  
            
            43     def process_batch(self, batch):
            
            44         """
            
            45         A helper function for update() and evaluate() that processes the batch
            
            46         Args:
            
            47             - batch (tuple of Tensors): a batch of data yielded by data loaders
            
            48         Output:
            
            49             - results (dictionary): information about the batch
            
            50                 - g (Tensor)
            
            51                 - y_true (Tensor)
            
            52                 - metadata (Tensor)
            
            53                 - loss (Tensor)
            
            54                 - metrics (Tensor)
            
            55               all Tensors are of size (batch_size,)
            
            56         """
            
            57         results = super().process_batch(batch)
            
            58         results['group_weight'] = self.group_weights
            
            59         return results
            
            60  
            
            61     def objective(self, results):
            
            62         """
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (80 > 79 characters)
 
               
               
63         Takes an output of SingleModelAlgorithm.process_batch() and computes the
             
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (81 > 79 characters)
 
               
               
64         optimized objective. For group DRO, the objective is the weighted average
             
            65         of losses, where groups have weights groupDRO.group_weights.
            
            66         Args:
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (82 > 79 characters)
 
               
               
67             - results (dictionary): output of SingleModelAlgorithm.process_batch()
             
            68         Output:
            
            69             - objective (Tensor): optimized objective; size (1,).
            
            70         """
            
            71         group_losses, _, _ = self.loss.compute_group_wise(
            
            72             results['y_pred'],
            
            73             results['y_true'],
            
            74             results['g'],
            
            75             self.grouper.n_groups,
            
            76             return_dict=False)
            
            77         return group_losses @ self.group_weights
            
            78  
            
            79     def _update(self, results):
            
            80         """
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (94 > 79 characters)
 
               
               
81         Process the batch, update the log, and update the model, group weights, and scheduler.
             
            82         Args:
            
            83             - batch (tuple of Tensors): a batch of data yielded by data loaders
            
            84         Output:
            
            85             - results (dictionary): information about the batch, such as:
            
            86                 - g (Tensor)
            
            87                 - y_true (Tensor)
            
            88                 - metadata (Tensor)
            
            89                 - loss (Tensor)
            
            90                 - metrics (Tensor)
            
            91                 - objective (float)
            
            92         """
            
            93         # compute group losses
            
            94         group_losses, _, _ = self.loss.compute_group_wise(
            
            95             results['y_pred'],
            
            96             results['y_true'],
            
            97             results['g'],
            
            98             self.grouper.n_groups,
            
            99             return_dict=False)
            
            100         # update group weights
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (109 > 79 characters)
 
               
               
101         self.group_weights = self.group_weights * torch.exp(self.group_weights_step_size * group_losses.data)
             
            102         self.group_weights = (self.group_weights / (self.group_weights.sum()))
            
            103         # save updated group weights
            
            104         results['group_weight'] = self.group_weights
            
            105         # update model
            
            106         super()._update(results)