1 import torch
            
            2  
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (80 > 79 characters)
3 from gds.common.metrics.metric import ElementwiseMetric, Metric, MultiTaskMetric 
            4 from gds.common.utils import maximum
            
            5  
            
            6  
            
            class Loss(Metric):
                """Metric that wraps a loss function and returns its output.

                Unlike ``ElementwiseLoss``, ``loss_fn`` is applied directly to
                the given predictions/targets and its result is returned as-is
                (no per-example handling).
                """

                def __init__(self, loss_fn, name=None):
                    """
                    Args:
                        - loss_fn (callable): called as loss_fn(y_pred, y_true)
                        - name (str, optional): metric name; 'loss' if None
                    """
                    self.loss_fn = loss_fn
                    if name is None:
                        name = 'loss'
                    super().__init__(name=name)

                def _compute(self, y_pred, y_true):
                    """
                    Helper for computing the loss on the given batch.
                    Args:
                        - y_pred (Tensor): Predicted targets or model output
                        - y_true (Tensor): True targets
                    Output:
                        - loss (Tensor): whatever self.loss_fn returns; this
                          is NOT an element-wise tensor. Presumably loss_fn is
                          configured to reduce to a single value (TODO confirm
                          its reduction at the call sites).
                    """
                    return self.loss_fn(y_pred, y_true)

                def worst(self, metrics):
                    """
                    Given a list/numpy array/Tensor of metrics, computes the
                    worst-case metric. For a loss, larger is worse, so the
                    worst case is the maximum.
                    Args:
                        - metrics (Tensor, numpy array, or list): Metrics
                    Output:
                        - worst_metric (float): Worst-case metric
                    """
                    return maximum(metrics)
            
            35  
            
            36  
            
            class ElementwiseLoss(ElementwiseMetric):
                """Per-example loss metric.

                Only ``torch.nn.BCEWithLogitsLoss`` and
                ``torch.nn.CrossEntropyLoss`` instances are supported. They
                presumably need ``reduction='none'`` so that one value per
                example is produced (TODO confirm at the call sites).
                """

                def __init__(self, loss_fn, name=None):
                    """
                    Args:
                        - loss_fn (torch.nn.Module): BCEWithLogitsLoss or
                          CrossEntropyLoss instance
                        - name (str, optional): metric name; 'loss' if None
                    """
                    self.loss_fn = loss_fn
                    if name is None:
                        name = 'loss'
                    super().__init__(name=name)

                def _compute_element_wise(self, y_pred, y_true):
                    """
                    Helper for computing element-wise metric, implemented for
                    each metric
                    Args:
                        - y_pred (Tensor): Predicted targets or model output
                        - y_true (Tensor): True targets
                    Output:
                        - element_wise_metrics (Tensor): tensor of size
                          (batch_size, )
                    Raises:
                        - NotImplementedError: if loss_fn is neither a
                          BCEWithLogitsLoss nor a CrossEntropyLoss
                    """
                    if isinstance(self.loss_fn, torch.nn.BCEWithLogitsLoss):
                        losses = self.loss_fn(y_pred.float(), y_true.float())
                        # Drop a trailing singleton dim, e.g. (batch, 1) ->
                        # (batch,). Only squeeze when ndim > 1: squeezing a
                        # 1-D tensor holding a single example would collapse
                        # it to a 0-dim scalar instead of shape (1,).
                        if losses.dim() > 1:
                            losses = losses.squeeze(dim=-1)
                        return losses
                    if isinstance(self.loss_fn, torch.nn.CrossEntropyLoss):
                        return self.loss_fn(y_pred, y_true)
                    raise NotImplementedError(
                        'ElementwiseLoss does not support loss_fn of type '
                        '{}'.format(type(self.loss_fn).__name__)
                    )

                def worst(self, metrics):
                    """
                    Given a list/numpy array/Tensor of metrics, computes the
                    worst-case metric. For a loss, larger is worse, so the
                    worst case is the maximum.
                    Args:
                        - metrics (Tensor, numpy array, or list): Metrics
                    Output:
                        - worst_metric (float): Worst-case metric
                    """
                    return maximum(metrics)
            
            74  
            
            75  
            
            class MultiTaskLoss(MultiTaskMetric):
                """Loss metric evaluated on flattened multi-task inputs.

                ``loss_fn`` should be element-wise. For BCEWithLogitsLoss and
                CrossEntropyLoss the inputs are cast before the call; any other
                loss receives them unchanged.
                """

                def __init__(self, loss_fn, name=None):
                    """
                    Args:
                        - loss_fn (callable): element-wise loss function
                        - name (str, optional): metric name; 'loss' if None
                    """
                    self.loss_fn = loss_fn  # should be element-wise
                    super().__init__(name='loss' if name is None else name)

                def _compute_flattened(self, flattened_y_pred, flattened_y_true):
                    """
                    Apply loss_fn to the flattened predictions and targets.
                    Args:
                        - flattened_y_pred (Tensor): flattened model output
                        - flattened_y_true (Tensor): flattened true targets
                    Output:
                        - flattened_loss: the value returned by self.loss_fn
                    """
                    pred, target = flattened_y_pred, flattened_y_true
                    if isinstance(self.loss_fn, torch.nn.BCEWithLogitsLoss):
                        # BCEWithLogitsLoss: both operands are cast to float.
                        pred, target = pred.float(), target.float()
                    elif isinstance(self.loss_fn, torch.nn.CrossEntropyLoss):
                        # CrossEntropyLoss: only the targets are cast, to long.
                        target = target.long()
                    return self.loss_fn(pred, target)

                def worst(self, metrics):
                    """
                    Given a list/numpy array/Tensor of metrics, computes the
                    worst-case metric, i.e. the maximum (larger loss is worse).
                    Args:
                        - metrics (Tensor, numpy array, or list): Metrics
                    Output:
                        - worst_metric (float): Worst-case metric
                    """
                    return maximum(metrics)