import torch
from algorithms.single_model_algorithm import SingleModelAlgorithm
from models.initializer import initialize_model
from utils import move_to
import copy
from gds.common.utils import split_into_groups
from torch_geometric.data import Batch


class MLDG(SingleModelAlgorithm):
11 """Domain-Adversarial Neural Networks (abstract class)"""

    def __init__(self, config, d_out, grouper, loss,
                 metric, n_train_steps):
        model = initialize_model(
            config, d_out=d_out, is_featurizer=False
        ).to(config.device)

        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        self.config = config

    def update(self, batch):
        x, y_true, metadata = batch

        x = move_to(x, self.device)
        y_true = move_to(y_true, self.device)
        g = move_to(self.grouper.metadata_to_group(metadata), self.device)
        results = {
            'g': g,
            'y_true': y_true,
            'metadata': metadata,
        }

        self.optimizer.zero_grad()
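        # Pre-allocate zero gradient buffers so the per-pair meta-gradients
        # computed below can be accumulated in place with .add_()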
        for p in self.model.parameters():
            if p.grad is None:
                p.grad = torch.zeros_like(p)

        unique_groups, group_indices, _ = split_into_groups(results['g'])
        n_groups_per_batch = unique_groups.numel()
        rand_group_index = torch.randperm(n_groups_per_batch)
        group_pairs = []
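        # Pair each group with its cyclic successor in a shuffled order:
        # group i serves as the meta-train task and group j as the held-out
        # meta-test task for that pair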
        for i in range(rand_group_index.shape[0]):
            j = i + 1 if i < rand_group_index.shape[0] - 1 else 0
            group_i, group_j = rand_group_index[i], rand_group_index[j]
            group_pairs.append(
                (group_indices[group_i], group_indices[group_j])
            )

        objective = 0
        for (group_indices_i, group_indices_j) in group_pairs:
            # fine-tune a clone of the network on meta-train group "i"
            inner_net = copy.deepcopy(self.model)

            inner_opt = torch.optim.Adam(
                inner_net.parameters(),
                lr=self.config.lr,
                weight_decay=0
            )

            inner_obj = self.objective({
                'y_pred': inner_net(Batch.from_data_list(x[group_indices_i])),
                'y_true': y_true[group_indices_i]
            })
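
            # One optimizer step on group "i" yields adapted parameters in
            # inner_net, while self.model keeps the original parameters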
            try:
                inner_opt.zero_grad()
                inner_obj.backward()
                inner_opt.step()

                # The network has now accumulated gradients Gi
                # The clone network now has parameters P - lr * Gi
                for p_tgt, p_src in zip(self.model.parameters(),
                                        inner_net.parameters()):
                    if p_src.grad is not None:
                        p_tgt.grad.data.add_(
                            p_src.grad.data / n_groups_per_batch
                        )
            except Exception:
                print('group_i backward error')

            objective += inner_obj.item()

            # this computes Gj on the clone network
            loss_inner_j = self.objective({
                'y_pred': inner_net(Batch.from_data_list(x[group_indices_j])),
                'y_true': y_true[group_indices_j]
            })
            # Guard against pcba, where a group's labels can be all NaN and
            # the backward pass fails
            try:
                grad_inner_j = torch.autograd.grad(
                    loss_inner_j,
                    inner_net.parameters(),
                    allow_unused=True
                )
            except Exception:
                print('group_j backward error')
                grad_inner_j = None

            # `objective` is accumulated for reporting purposes only;
            # .item() detaches the value from the graph
            objective += (self.config.mldg_beta * loss_inner_j).item()

            if grad_inner_j is not None:
                for p, g_j in zip(self.model.parameters(), grad_inner_j):
                    if g_j is not None:
                        p.grad.data.add_(
                            self.config.mldg_beta * g_j.data
                            / n_groups_per_batch
                        )

            # The network has now accumulated gradients Gi + beta * Gj
            # Repeat for all (meta-train, meta-test) pairs, then do .step()

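        # At this point the gradient buffers of self.model hold the
        # first-order MLDG direction, averaged over all group pairs:
        # grad_Fi(theta) + beta * grad_Fj(theta_i')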
        self.optimizer.step()
        objective /= n_groups_per_batch

        results['objective'] = objective
        self.step_schedulers(
            is_epoch=False,
            metrics=results,
            log_access=False
        )

        # log results
        results['y_pred'] = self.model(x).data
        self.update_log(results)
        return self.sanitize_dict(results)

    def objective(self, results):
        return self.loss.compute(
            results['y_pred'], results['y_true'], return_dict=False
        )
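

# A minimal usage sketch (hypothetical wiring; in practice the framework's
# training script constructs these objects). `config` is assumed to expose
# `lr`, `mldg_beta`, and `device`, as used above:
#
#     algorithm = MLDG(config=config, d_out=d_out, grouper=grouper,
#                      loss=loss, metric=metric, n_train_steps=n_train_steps)
#     for batch in train_loader:
#         results = algorithm.update(batch)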