1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
|
from algorithms.ERM import ERM
from algorithms.IRM import IRM
from algorithms.deepCORAL import DeepCORAL
from algorithms.groupDRO import GroupDRO
from algorithms.FLAG import FLAG
from algorithms.GCL import GCL
from algorithms.GSN import GSN
from algorithms.DANN import DANN, CDANN, DANNG
from algorithms.MLDG import MLDG
from configs.supported import algo_log_metrics
from losses import initialize_loss
from gds.common.utils import get_counts
def initialize_algorithm(config, datasets, full_dataset, train_grouper):
    """Construct and return the training algorithm named by ``config.algorithm``.

    Args:
        config: experiment configuration; must provide ``algorithm``,
            ``n_epochs``, and ``algo_log_metric`` attributes.
        datasets: dict with a ``'train'`` entry holding ``'dataset'`` and
            ``'loader'`` objects.
        full_dataset: the complete dataset object (only consumed by GSN).
        train_grouper: grouper used to map metadata to group ids.

    Returns:
        An initialized algorithm instance (e.g. ERM, GroupDRO, DANN, ...).

    Raises:
        RuntimeError: if the output dimension ``d_out`` cannot be inferred
            from the training dataset.
        ValueError: if the requested algorithm is unsupported for detection
            datasets, or is not recognized at all.
    """
    train_dataset = datasets['train']['dataset']
    train_loader = datasets['train']['loader']

    # Configure the final layer of the networks used.
    # The code below are defaults. Edit this if you need special config for your model.
    if train_dataset.is_classification:
        if train_dataset.y_size == 1:
            # For single-task classification, we have one output per class
            d_out = train_dataset.n_classes
        elif train_dataset.y_size is None:
            d_out = train_dataset.n_classes
        elif (train_dataset.y_size > 1) and (train_dataset.n_classes == 2):
            # For multi-task binary classification (each output is the logit for each binary class)
            d_out = train_dataset.y_size
        else:
            raise RuntimeError('d_out not defined.')
    elif train_dataset.is_detection:
        # For detection, d_out is the number of classes
        d_out = train_dataset.n_classes
        if config.algorithm in ['deepCORAL', 'IRM']:
            raise ValueError(f'{config.algorithm} is not currently supported for detection datasets.')
    else:
        # For regression, we have one output per target dimension
        d_out = train_dataset.y_size

    # Other config
    n_train_steps = len(train_loader) * config.n_epochs
    loss = initialize_loss(config, d_out)
    metric = algo_log_metrics[config.algo_log_metric]

    # Every algorithm constructor shares these six keyword arguments;
    # building them once avoids the previous nine-way copy-paste.
    common_kwargs = dict(
        config=config,
        d_out=d_out,
        grouper=train_grouper,
        loss=loss,
        metric=metric,
        n_train_steps=n_train_steps,
    )

    # Algorithms that take only the common arguments, keyed by config name.
    simple_algorithms = {
        'ERM': ERM,
        'deepCORAL': DeepCORAL,
        'IRM': IRM,
        'FLAG': FLAG,
        'GCL': GCL,
        'DANN': DANN,
        'CDANN': CDANN,
        'DANN-G': DANNG,
        'MLDG': MLDG,
    }

    if config.algorithm in simple_algorithms:
        algorithm = simple_algorithms[config.algorithm](**common_kwargs)
    elif config.algorithm == 'groupDRO':
        # GroupDRO additionally needs to know which groups appear in training.
        train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)
        is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0
        algorithm = GroupDRO(is_group_in_train=is_group_in_train, **common_kwargs)
    elif config.algorithm == 'GSN':
        # GSN additionally needs access to the full dataset.
        algorithm = GSN(full_dataset=full_dataset, **common_kwargs)
    else:
        raise ValueError(f"Algorithm {config.algorithm} not recognized")

    return algorithm
|