1 from algorithms.ERM import ERM
2 from algorithms.IRM import IRM
3 from algorithms.deepCORAL import DeepCORAL
4 from algorithms.groupDRO import GroupDRO
5 from algorithms.FLAG import FLAG
6 from algorithms.GCL import GCL
7 from algorithms.GSN import GSN
8 from algorithms.DANN import DANN, CDANN, DANNG
9 from algorithms.MLDG import MLDG
10 from configs.supported import algo_log_metrics
11 from losses import initialize_loss
12
13 from gds.common.utils import get_counts
14
15
def _infer_d_out(train_dataset, algorithm):
    """Return the model output dimension implied by ``train_dataset``.

    These are defaults; edit this helper if a model needs a special final
    layer.

    * Classification: ``n_classes`` outputs for single-task data
      (``y_size`` of 1 or None), or ``y_size`` outputs (one logit per
      binary task) for multi-task binary data.
    * Detection: ``n_classes`` outputs; ``deepCORAL`` and ``IRM`` are
      rejected.
    * Anything else (regression): ``y_size`` outputs, one per target
      dimension.

    Args:
        train_dataset: Training dataset exposing ``is_classification``,
            ``is_detection``, ``y_size`` and ``n_classes``.
        algorithm: Name of the selected algorithm (``config.algorithm``).

    Returns:
        The output dimension to configure the final layer with.

    Raises:
        RuntimeError: For a classification dataset whose ``y_size`` /
            ``n_classes`` combination has no defined output size.
        ValueError: For ``deepCORAL`` or ``IRM`` on a detection dataset.
    """
    if train_dataset.is_classification:
        y_size = train_dataset.y_size
        if y_size is None or y_size == 1:
            # Single-task classification: one output per class.
            return train_dataset.n_classes
        if y_size > 1 and train_dataset.n_classes == 2:
            # Multi-task binary classification: each output is the logit
            # for one binary task.
            return y_size
        raise RuntimeError('d_out not defined.')

    if train_dataset.is_detection:
        if algorithm in ('deepCORAL', 'IRM'):
            raise ValueError(
                f'{algorithm} is not currently supported '
                'for detection datasets.')
        # Detection: d_out is the number of classes.
        return train_dataset.n_classes

    # Regression: one output per target dimension.
    return train_dataset.y_size


def initialize_algorithm(config, datasets, full_dataset, train_grouper):
    """Build the training algorithm selected by ``config.algorithm``.

    Args:
        config: Experiment config. This function reads ``algorithm``,
            ``n_epochs`` and ``algo_log_metric``, and also passes the whole
            config to ``initialize_loss`` and to the algorithm constructor.
        datasets: Mapping whose ``['train']`` entry holds the training
            ``'dataset'`` and its ``'loader'``.
        full_dataset: Passed through only to the ``GSN`` algorithm.
        train_grouper: Grouper passed to every algorithm; ``groupDRO`` also
            uses it to map the training metadata to groups.

    Returns:
        The constructed algorithm instance.

    Raises:
        RuntimeError: If the output dimension cannot be inferred for a
            classification dataset.
        ValueError: If the algorithm is not supported for detection
            datasets, or if ``config.algorithm`` is not recognized.
    """
    train_dataset = datasets['train']['dataset']
    train_loader = datasets['train']['loader']

    # Configure the final layer of the networks used.
    d_out = _infer_d_out(train_dataset, config.algorithm)

    # Other config
    n_train_steps = len(train_loader) * config.n_epochs
    loss = initialize_loss(config, d_out)
    metric = algo_log_metrics[config.algo_log_metric]

    # Keyword arguments shared by every algorithm constructor.
    common_kwargs = dict(
        config=config,
        d_out=d_out,
        grouper=train_grouper,
        loss=loss,
        metric=metric,
        n_train_steps=n_train_steps,
    )

    algorithm_name = config.algorithm
    if algorithm_name == 'groupDRO':
        # GroupDRO additionally receives a boolean mask of groups that have
        # a positive count in the training split (get_counts presumably
        # returns per-group example counts).
        train_g = train_grouper.metadata_to_group(
            train_dataset.metadata_array)
        is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0
        return GroupDRO(**common_kwargs, is_group_in_train=is_group_in_train)
    if algorithm_name == 'GSN':
        return GSN(**common_kwargs, full_dataset=full_dataset)

    # Algorithms that need nothing beyond the shared keyword arguments.
    standard_algorithms = {
        'ERM': ERM,
        'deepCORAL': DeepCORAL,
        'IRM': IRM,
        'FLAG': FLAG,
        'GCL': GCL,
        'DANN': DANN,
        'CDANN': CDANN,
        'DANN-G': DANNG,
        'MLDG': MLDG,
    }
    algorithm_cls = standard_algorithms.get(algorithm_name)
    if algorithm_cls is None:
        raise ValueError(f"Algorithm {config.algorithm} not recognized")
    return algorithm_cls(**common_kwargs)