# experiments/losses.py  (module: experiments.losses)
import torch.nn as nn

from gds.common.metrics.all_metrics import MSE
from gds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss


def initialize_loss(config, d_out):
    """Build the loss object selected by ``config.loss_function``.

    Recognized names and what is returned for each:

    * ``'cross_entropy'``        -- ``ElementwiseLoss`` around ``nn.CrossEntropyLoss``
    * ``'BCEWithLogitsLoss'``    -- ``ElementwiseLoss`` around ``nn.BCEWithLogitsLoss``
    * ``'lm_cross_entropy'``     -- ``MultiTaskLoss`` around ``nn.CrossEntropyLoss``
    * ``'mse'``                  -- ``MSE(name='loss')``
    * ``'multitask_bce'``        -- ``MultiTaskLoss`` around ``nn.BCEWithLogitsLoss``
    * ``'fasterrcnn_criterion'`` -- ``ElementwiseLoss`` around ``FasterRCNNLoss``

    All ``torch.nn`` criteria are created with ``reduction='none'``.

    Args:
        config: Object exposing ``loss_function``; ``device`` is read as well,
            but only for ``'fasterrcnn_criterion'``.
        d_out: Accepted for interface compatibility; not used by this function.

    Returns:
        The loss wrapper / metric instance for the requested name.

    Raises:
        ValueError: If ``config.loss_function`` matches none of the names above.
    """
    requested = config.loss_function

    def _build_fasterrcnn():
        # Imported on demand so the detection module is only loaded when
        # this criterion is actually requested.
        from models.detection.fasterrcnn import FasterRCNNLoss
        return ElementwiseLoss(loss_fn=FasterRCNNLoss(config.device))

    # Ordered (name, zero-argument builder) pairs; nothing is constructed
    # until the matching builder is invoked.
    builders = (
        ('cross_entropy',
         lambda: ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none'))),
        ('BCEWithLogitsLoss',
         lambda: ElementwiseLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none'))),
        ('lm_cross_entropy',
         lambda: MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none'))),
        ('mse',
         lambda: MSE(name='loss')),
        ('multitask_bce',
         lambda: MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none'))),
        ('fasterrcnn_criterion',
         _build_fasterrcnn),
    )

    for known_name, build in builders:
        if requested == known_name:
            return build()

    raise ValueError(f'config.loss_function {requested} not recognized')