⬅ losses.py source

1 import torch.nn as nn
2  
3 from gds.common.metrics.all_metrics import MSE
  • F401 'gds.common.metrics.loss.Loss' imported but unused
4 from gds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss
5  
6  
def initialize_loss(config, d_out):
    """Build the training loss selected by ``config.loss_function``.

    Args:
        config: Configuration object. ``config.loss_function`` selects the
            loss; ``config.device`` is additionally read when the
            selection is ``'fasterrcnn_criterion'``.
        d_out: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        The loss object for the selected name:

        * ``'cross_entropy'``        -> ``ElementwiseLoss`` wrapping
          ``nn.CrossEntropyLoss(reduction='none')``
        * ``'BCEWithLogitsLoss'``    -> ``ElementwiseLoss`` wrapping
          ``nn.BCEWithLogitsLoss(reduction='none')``
        * ``'lm_cross_entropy'``     -> ``MultiTaskLoss`` wrapping
          ``nn.CrossEntropyLoss(reduction='none')``
        * ``'mse'``                  -> ``MSE(name='loss')``
        * ``'multitask_bce'``        -> ``MultiTaskLoss`` wrapping
          ``nn.BCEWithLogitsLoss(reduction='none')``
        * ``'fasterrcnn_criterion'`` -> ``ElementwiseLoss`` wrapping
          ``FasterRCNNLoss(config.device)``

    Raises:
        ValueError: If ``config.loss_function`` is not one of the names
            listed above (matching is exact and case-sensitive).
    """
    loss_name = config.loss_function

    # The torch criteria are built with reduction='none' so they return
    # unreduced values; presumably the wrapper classes apply their own
    # reduction -- confirm against gds.common.metrics.loss.
    if loss_name == 'cross_entropy':
        return ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none'))

    elif loss_name == 'BCEWithLogitsLoss':
        return ElementwiseLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none'))

    elif loss_name == 'lm_cross_entropy':
        return MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none'))

    elif loss_name == 'mse':
        return MSE(name='loss')

    elif loss_name == 'multitask_bce':
        return MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none'))

    elif loss_name == 'fasterrcnn_criterion':
        # Imported here rather than at module level so the detection model
        # is only loaded when this loss is requested (presumably to avoid
        # an import cost or cycle -- confirm).
        from models.detection.fasterrcnn import FasterRCNNLoss
        return ElementwiseLoss(loss_fn=FasterRCNNLoss(config.device))

    else:
        raise ValueError(
            f'config.loss_function {config.loss_function} not recognized'
        )