import torch
import torch.nn.functional as F
from algorithms.single_model_algorithm import SingleModelAlgorithm
from models.initializer import initialize_model
from utils import move_to
import copy
from gds.common.utils import split_into_groups
from torch_geometric.data import Batch

class MLDG(SingleModelAlgorithm):
    """
    MLDG: Meta-Learning for Domain Generalization.

    Reference: Li et al., "Learning to Generalize: Meta-Learning for Domain
    Generalization", AAAI 2018.
    """
def __init__(self, config, d_out, grouper, loss,
metric, n_train_steps):
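        # MLDG adapts the full model end-to-end, so it is initialized without a
        # separate featurizer/classifier split (is_featurizer=False).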
model = initialize_model(config, d_out=d_out, is_featurizer=False).to(config.device)
# initialize module
super().__init__(
config=config,
model=model,
grouper=grouper,
loss=loss,
metric=metric,
n_train_steps=n_train_steps,
)
self.config = config

    def update(self, batch):
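        """
        Perform a single MLDG meta-update on one batch.

        The batch is split into groups; for each (meta-train, meta-test) pair of
        groups, a deep copy of the model takes one inner Adam step on the
        meta-train group, the meta-test loss of the adapted copy is differentiated,
        and its gradient (scaled by config.mldg_beta) is accumulated into the
        original model's gradients before a single outer optimizer step.
        """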
x, y_true, metadata = batch
x = move_to(x, self.device)
y_true = move_to(y_true, self.device)
g = move_to(self.grouper.metadata_to_group(metadata), self.device)
results = {
'g': g,
'y_true': y_true,
'metadata': metadata,
}
self.optimizer.zero_grad()
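        # Pre-allocate zero gradient buffers so that per-group gradients can be
        # accumulated in place with .add_() below.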
for p in self.model.parameters():
if p.grad is None:
p.grad = torch.zeros_like(p)
unique_groups, group_indices, _ = split_into_groups(results['g'])
n_groups_per_batch = unique_groups.numel()
rand_group_index = torch.randperm(n_groups_per_batch)
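        # Pair the groups cyclically in a random order; each pair (i, j) is used as
        # a (meta-train, meta-test) split.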
group_pairs = []
        for i in range(n_groups_per_batch):
            j = (i + 1) % n_groups_per_batch
            group_i, group_j = rand_group_index[i], rand_group_index[j]
            group_pairs.append((group_indices[group_i], group_indices[group_j]))
objective = 0
for (group_indices_i, group_indices_j) in group_pairs:
# fine tune clone-network on task "i"
inner_net = copy.deepcopy(self.model)
inner_opt = torch.optim.Adam(
inner_net.parameters(),
lr=self.config.lr,
weight_decay=0
)
inner_obj = self.objective({
'y_pred': inner_net(Batch.from_data_list(x[group_indices_i])),
'y_true': y_true[group_indices_i]
})
            try:
                inner_opt.zero_grad()
                inner_obj.backward()
                inner_opt.step()
                # The clone-network has accumulated gradients Gi and, after the
                # step, holds parameters P - lr * Gi. Copy Gi (averaged over
                # groups) into the original network's gradient buffers.
                for p_tgt, p_src in zip(self.model.parameters(),
                                        inner_net.parameters()):
                    if p_src.grad is not None:
                        p_tgt.grad.data.add_(p_src.grad.data / n_groups_per_batch)
            except Exception:
                # The backward pass can fail, e.g. when a group's labels are all NaN.
                print('group_i backward error')
objective += inner_obj.item()
# this computes Gj on the clone-network
loss_inner_j = self.objective({
'y_pred': inner_net(Batch.from_data_list(x[group_indices_j])),
'y_true': y_true[group_indices_j]
})
            # Guard against batches (e.g. from pcba) where a group's labels could
            # all be NaN and the backward pass fails.
            try:
                grad_inner_j = torch.autograd.grad(loss_inner_j, inner_net.parameters(), allow_unused=True)
            except Exception:
                print('group_j backward error')
                grad_inner_j = None
# `objective` is populated for reporting purposes
objective += (self.config.mldg_beta * loss_inner_j).item()
            if grad_inner_j is not None:
for p, g_j in zip(self.model.parameters(), grad_inner_j):
if g_j is not None:
p.grad.data.add_(self.config.mldg_beta * g_j.data / n_groups_per_batch)
# The network has now accumulated gradients Gi + beta * Gj
# Repeat for all train-test splits, do .step()
self.optimizer.step()
objective /= n_groups_per_batch
results['objective'] = objective
self.step_schedulers(
is_epoch=False,
metrics=results,
log_access=False
)
# log results
results['y_pred'] = self.model(x).data
self.update_log(results)
return self.sanitize_dict(results)

    def objective(self, results):
return self.loss.compute(results['y_pred'], results['y_true'], return_dict=False)
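
# Example usage (sketch only): how a training loop might drive this algorithm.
# The objects below (`config`, `d_out`, `grouper`, `loss`, `metric`,
# `n_train_steps`, `train_loader`) are assumed to be constructed by the repo's
# training harness; the names are illustrative, not part of this module.
#
#   algorithm = MLDG(config=config, d_out=d_out, grouper=grouper,
#                    loss=loss, metric=metric, n_train_steps=n_train_steps)
#   for batch in train_loader:
#       results = algorithm.update(batch)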