experiments.algorithms.initializer

experiments/algorithms/initializer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from algorithms.ERM import ERM
from algorithms.IRM import IRM
from algorithms.deepCORAL import DeepCORAL
from algorithms.groupDRO import GroupDRO
from algorithms.FLAG import FLAG
from algorithms.GCL import GCL
from algorithms.GSN import GSN
from algorithms.DANN import DANN, CDANN, DANNG
from algorithms.MLDG import MLDG
from configs.supported import algo_log_metrics
from losses import initialize_loss

from gds.common.utils import get_counts


def _infer_d_out(config, train_dataset):
    """Return the size of the model's final layer for ``train_dataset``.

    These rules are defaults; edit them if your model needs a special
    output configuration.

    Raises:
        RuntimeError: for a classification dataset whose ``y_size`` /
            ``n_classes`` combination matches none of the handled cases.
        ValueError: for a detection dataset when ``config.algorithm`` is
            ``'deepCORAL'`` or ``'IRM'``.
    """
    if train_dataset.is_classification:
        if train_dataset.y_size == 1 or train_dataset.y_size is None:
            # Single-task classification: one output per class.
            return train_dataset.n_classes
        if (train_dataset.y_size > 1) and (train_dataset.n_classes == 2):
            # Multi-task binary classification: one logit per binary task.
            return train_dataset.y_size
        raise RuntimeError('d_out not defined.')

    if train_dataset.is_detection:
        # Detection: d_out is the number of classes.
        d_out = train_dataset.n_classes
        if config.algorithm in ('deepCORAL', 'IRM'):
            raise ValueError(f'{config.algorithm} is not currently supported for detection datasets.')
        return d_out

    # Regression: one output per target dimension.
    return train_dataset.y_size


def initialize_algorithm(config, datasets, full_dataset, train_grouper):
    """Build the training algorithm selected by ``config.algorithm``.

    Args:
        config: Run configuration. This function reads ``algorithm``,
            ``n_epochs`` and ``algo_log_metric`` from it and forwards the
            whole object to ``initialize_loss`` and to the algorithm
            constructor.
        datasets: Mapping in which ``datasets['train']['dataset']`` is the
            training dataset and ``datasets['train']['loader']`` is its
            loader (only ``len()`` of the loader is used here).
        full_dataset: Forwarded to the constructor only when the selected
            algorithm is ``'GSN'``; unused otherwise.
        train_grouper: Passed to every algorithm as ``grouper``. For
            ``'groupDRO'`` it is also used to work out which groups occur
            in the training metadata.

    Returns:
        The constructed algorithm instance.

    Raises:
        RuntimeError: if the output dimension cannot be derived for a
            classification dataset.
        ValueError: if the algorithm is unsupported for detection datasets,
            or if ``config.algorithm`` is not a recognized name.
    """
    train_dataset = datasets['train']['dataset']
    train_loader = datasets['train']['loader']

    # Configure the final layer of the networks used.
    d_out = _infer_d_out(config, train_dataset)

    # Arguments shared by every algorithm constructor.
    n_train_steps = len(train_loader) * config.n_epochs
    loss = initialize_loss(config, d_out)
    metric = algo_log_metrics[config.algo_log_metric]

    # Built inside the function so the class names are only resolved once an
    # algorithm is actually being constructed.
    algorithm_classes = {
        'ERM': ERM,
        'groupDRO': GroupDRO,
        'deepCORAL': DeepCORAL,
        'IRM': IRM,
        'FLAG': FLAG,
        'GCL': GCL,
        'GSN': GSN,
        'DANN': DANN,
        'CDANN': CDANN,
        'DANN-G': DANNG,
        'MLDG': MLDG,
    }
    try:
        algorithm_cls = algorithm_classes[config.algorithm]
    except (KeyError, TypeError):
        # TypeError covers an unhashable ``config.algorithm``, which the
        # caller should see as "not recognized" just like an unknown name.
        raise ValueError(f"Algorithm {config.algorithm} not recognized") from None

    # Only two algorithms take arguments beyond the shared set.
    extra_kwargs = {}
    if config.algorithm == 'groupDRO':
        train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)
        extra_kwargs['is_group_in_train'] = get_counts(train_g, train_grouper.n_groups) > 0
    elif config.algorithm == 'GSN':
        extra_kwargs['full_dataset'] = full_dataset

    return algorithm_cls(
        config=config,
        d_out=d_out,
        grouper=train_grouper,
        loss=loss,
        metric=metric,
        n_train_steps=n_train_steps,
        **extra_kwargs)