⬅ algorithms/DANN.py source

import copy

import torch
import torch.nn.functional as F
from algorithms.single_model_algorithm import SingleModelAlgorithm
from models.initializer import initialize_model
from utils import move_to
from models.gnn import GNN_node
from torch_geometric.nn import global_mean_pool


class AbstractDANN(SingleModelAlgorithm):
    """Domain-Adversarial Neural Networks (abstract class)"""

    # def __init__(self, input_shape, num_classes, num_train_domains,
    #              hparams, conditional, class_balance):
    def __init__(self, config, d_out, grouper, loss,
                 metric, n_train_steps, conditional, class_balance):
        featurizer, classifier = initialize_model(
            config, d_out=d_out, is_featurizer=True)
        featurizer = featurizer.to(config.device)
        classifier = classifier.to(config.device)
        model = torch.nn.Sequential(featurizer, classifier).to(config.device)

        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        # domain space shouldn't be too large
        assert config.num_train_domains <= 1000

        self.featurizer = featurizer
        self.classifier = classifier
        self.register_buffer('update_count', torch.tensor([0]))

        ##############################################

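        # The domain discriminator maps featurizer outputs to a prediction
        # over the training domains. For the conditional variant (CDANN),
        # class embeddings are added to the features before they are passed
        # to the discriminator (see update()).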
        self.hparams_lambda = config.dann_lambda
        self.conditional = conditional
        self.class_balance = class_balance
        num_classes = d_out
        emb_dim = self.featurizer.d_out
        self.discriminator = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_dim, config.num_train_domains),
        ).to(config.device)
        self.class_embeddings = torch.nn.Embedding(
            num_classes, self.featurizer.d_out).to(config.device)
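        # Two optimizers for the adversarial game: disc_opt trains the
        # discriminator (and class embeddings) to predict the domain, while
        # gen_opt trains the featurizer/classifier against it (see update()).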

        # Optimizers
        self.disc_opt = torch.optim.Adam(
            (list(self.discriminator.parameters()) +
             list(self.class_embeddings.parameters())),
            lr=config.lr,
            weight_decay=0,
            betas=(0.5, 0.9))

        self.gen_opt = torch.optim.Adam(
            (list(self.featurizer.parameters()) +
             list(self.classifier.parameters())),
            lr=config.lr,
            weight_decay=0,
            betas=(0.5, 0.9))

    def process_batch(self, batch):
        """
        Override
        """
        # forward pass
        x, y_true, metadata = batch
        x = x.to(self.device)
        y_true = y_true.to(self.device)
        g = self.grouper.metadata_to_group(metadata).to(self.device)
        features = self.featurizer(x)
        outputs = self.classifier(features)

        # package the results
        results = {
            'g': g,
            'y_true': y_true,
            'y_pred': outputs,
            'metadata': metadata,
            'features': features,
        }
        return results

    def update(self, batch):
        x, y_true, metadata = batch
        x = move_to(x, self.device)
        y_true = move_to(y_true, self.device)
        g = move_to(self.grouper.metadata_to_group(metadata), self.device)
        results = {
            'g': g,
            'y_true': y_true,
            'metadata': metadata,
        }

        self.update_count += 1
        z = self.featurizer(x)
        if self.conditional:
            disc_input = z + self.class_embeddings(y_true)
        else:
            disc_input = z
        disc_out = self.discriminator(disc_input)

        # metadata[:, 0] is assumed to hold the domain label
        disc_labels = move_to(metadata[:, 0].flatten(), self.device)

        if self.class_balance:
            # weight each example inversely to its class frequency
            y_counts = F.one_hot(y_true).sum(dim=0)
            weights = 1. / (y_counts[y_true] * y_counts.shape[0]).float()
            disc_loss = F.cross_entropy(disc_out, disc_labels,
                                        reduction='none')
            disc_loss = (weights * disc_loss).sum()
        else:
            disc_loss = F.cross_entropy(disc_out, disc_labels)

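        # Gradient penalty on the discriminator: penalize the squared
        # gradient of the predicted domain probabilities w.r.t. the
        # discriminator input. The coefficient below is hard-coded to 0,
        # so the penalty is effectively disabled.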
        disc_softmax = F.softmax(disc_out, dim=1)
        input_grad = torch.autograd.grad(disc_softmax[:, disc_labels].sum(),
                                         [disc_input], create_graph=True)[0]
        grad_penalty = (input_grad ** 2).sum(dim=1).mean(dim=0)
        hparams_grad_penalty = 0
        disc_loss += hparams_grad_penalty * grad_penalty

        d_steps_per_g = 1
        all_preds = self.classifier(z)
        results['y_pred'] = all_preds
        classifier_loss = self.objective(results)

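        # Alternate discriminator and generator steps. With d_steps_per_g
        # set to 1 they are updated on alternating batches; the generator
        # step minimizes the task loss minus lambda * disc_loss, i.e. the
        # adversarial (gradient-reversal style) DANN objective.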
        if self.update_count.item() % (1 + d_steps_per_g) < d_steps_per_g:
            self.disc_opt.zero_grad()
            disc_loss.backward()
            self.disc_opt.step()
        else:
            gen_loss = (classifier_loss +
                        (self.hparams_lambda * -disc_loss))
            self.disc_opt.zero_grad()
            self.gen_opt.zero_grad()
            gen_loss.backward()
            self.gen_opt.step()

        results['objective'] = classifier_loss.item()
        self.step_schedulers(
            is_epoch=False,
            metrics=results,
            log_access=False)

        # log results
        self.update_log(results)
        return self.sanitize_dict(results)

    def objective(self, results):
        return self.loss.compute(
            results['y_pred'], results['y_true'], return_dict=False)


class OurAbstractDANN(SingleModelAlgorithm):
    """Domain-Adversarial Neural Networks with a graph-level domain
    discriminator (abstract class)"""

    # def __init__(self, input_shape, num_classes, num_train_domains,
    #              hparams, conditional, class_balance):
    def __init__(self, config, d_out, grouper, loss,
                 metric, n_train_steps):
        featurizer, pooler, classifier = initialize_model(
            config, d_out=d_out, is_featurizer=True, is_pooled=False)
        featurizer = featurizer.to(config.device)
        classifier = classifier.to(config.device)

        # placeholder: the parent class expects a single model, but
        # process_batch/update below use featurizer, pooler and classifier
        # directly
        model = classifier

        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        # domain space shouldn't be too large
        assert config.num_train_domains <= 1000

        self.featurizer = featurizer
        self.classifier = classifier
        self.pooler = pooler
        self.register_buffer('update_count', torch.tensor([0]))

        ##############################################
        self.hparams_lambda = config.dann_lambda
        emb_dim = self.featurizer.d_out
        # GNN type fixed at GIN for the discriminator, layer num fixed at 2
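        # The discriminator operates on node embeddings produced by the
        # featurizer: its own node encoder is removed, the embeddings are
        # refined by a small GNN, mean-pooled into a graph embedding, and
        # mapped to a prediction over the training domains.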
        self.discriminator_gnn = GNN_node(
            num_layer=2, emb_dim=emb_dim, dropout=0, batchnorm=False,
            dataset_group=config.model_kwargs['dataset_group']
        ).to(config.device)
        self.discriminator_gnn.destroy_node_encoder()
        self.discriminator_pool = global_mean_pool
        self.discriminator_mlp = torch.nn.Linear(
            emb_dim, config.num_train_domains).to(config.device)
        # self.discriminator_mlp = torch.nn.Sequential(
        #     torch.nn.Linear(emb_dim, emb_dim),
        #     torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU(),
        #     torch.nn.Linear(emb_dim, config.num_train_domains)
        # ).to(config.device)

        # Optimizers
        self.disc_opt = torch.optim.Adam(
            (list(self.discriminator_gnn.parameters()) +
             list(self.discriminator_mlp.parameters())),
            lr=config.lr,
            weight_decay=0,
            betas=(0.5, 0.9))

        self.gen_opt = torch.optim.Adam(
            (list(self.featurizer.parameters()) +
             list(self.classifier.parameters())),
            lr=config.lr,
            weight_decay=0,
            betas=(0.5, 0.9))

    def process_batch(self, batch):
        """
        Override
        """
        # forward pass
        x, y_true, metadata = batch
        x = x.to(self.device)
        y_true = y_true.to(self.device)
        g = self.grouper.metadata_to_group(metadata).to(self.device)
        features = self.pooler(*self.featurizer(x))
        outputs = self.classifier(features)

        # package the results
        results = {
            'g': g,
            'y_true': y_true,
            'y_pred': outputs,
            'metadata': metadata,
            'features': features,
        }
        return results

    def update(self, batch):
        x, y_true, metadata = batch
        x = move_to(x, self.device)
        y_true = move_to(y_true, self.device)
        g = move_to(self.grouper.metadata_to_group(metadata), self.device)
        results = {
            'g': g,
            'y_true': y_true,
            'metadata': metadata,
        }

        self.update_count += 1
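        # The unpooled featurizer is assumed to return a tuple of
        # (node_embeddings, batch_index). The node embeddings replace the
        # raw node features of a copy of the input graph, which is then fed
        # to the discriminator GNN and mean-pooled per graph.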
        disc_input = z = self.featurizer(x)
        disc_x = copy.deepcopy(x)
        disc_x.x = disc_input[0]
        disc_out = self.discriminator_gnn(disc_x)
        disc_out = self.discriminator_pool(disc_out, disc_input[1])
        disc_out = self.discriminator_mlp(disc_out)

        # metadata[:, 0] is assumed to hold the domain label
        disc_labels = move_to(metadata[:, 0].flatten(), self.device)
        disc_loss = F.cross_entropy(disc_out, disc_labels)

        d_steps_per_g = 1
        all_preds = self.classifier(self.pooler(*z))
        results['y_pred'] = all_preds
        classifier_loss = self.objective(results)

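        # Alternate discriminator and generator updates exactly as in
        # AbstractDANN.update above.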
        if self.update_count.item() % (1 + d_steps_per_g) < d_steps_per_g:
            self.disc_opt.zero_grad()
            disc_loss.backward()
            self.disc_opt.step()
        else:
            gen_loss = (classifier_loss +
                        (self.hparams_lambda * -disc_loss))
            self.disc_opt.zero_grad()
            self.gen_opt.zero_grad()
            gen_loss.backward()
            self.gen_opt.step()

        results['objective'] = classifier_loss.item()
        self.step_schedulers(
            is_epoch=False,
            metrics=results,
            log_access=False)

        # log results
        self.update_log(results)
        return self.sanitize_dict(results)

    def objective(self, results):
        return self.loss.compute(
            results['y_pred'], results['y_true'], return_dict=False)


class DANN(AbstractDANN):
    """Unconditional DANN"""

    def __init__(self, config, d_out, grouper, loss,
                 metric, n_train_steps):
        super(DANN, self).__init__(
            config, d_out, grouper, loss, metric, n_train_steps,
            conditional=False, class_balance=False)


class CDANN(AbstractDANN):
    """Conditional DANN"""

    def __init__(self, config, d_out, grouper, loss,
                 metric, n_train_steps):
        super(CDANN, self).__init__(
            config, d_out, grouper, loss, metric, n_train_steps,
            conditional=True, class_balance=True)


class DANNG(OurAbstractDANN):
    """DANN with a graph-level (GNN) domain discriminator"""

    def __init__(self, config, d_out, grouper, loss,
                 metric, n_train_steps):
        super(DANNG, self).__init__(
            config, d_out, grouper, loss, metric, n_train_steps)