1 from algorithms.single_model_algorithm import SingleModelAlgorithm
2 from models.initializer import initialize_model
3 import torch
4 from torch.nn.utils import clip_grad_norm_
5 from utils import move_to
6
7
class FLAG(SingleModelAlgorithm):
    """Single-model algorithm trained with a perturbation fed to the model.

    On every update a random perturbation tensor is drawn, handed to the
    model next to the batch, and pushed along the sign of its own gradient
    between successive forward/backward passes.  Parameter gradients from
    all passes accumulate before a single optimizer step is taken.

    NOTE(review): the class name suggests the FLAG adversarial augmentation
    scheme for graph networks -- confirm against the project docs.
    """

    def __init__(self, config, d_out, grouper, loss,
                 metric, n_train_steps):
        """Build the model and hand everything over to the base class.

        Args:
            config: run configuration; ``config.device`` is used to place
                the model and ``config.flag_step_size`` is read later by
                ``update``.
            d_out: forwarded to ``initialize_model`` together with config.
            grouper: forwarded to the base class; ``update`` calls its
                ``metadata_to_group`` method.
            loss: forwarded to the base class; ``objective`` calls its
                ``compute`` method.
            metric: forwarded to the base class unchanged.
            n_train_steps: forwarded to the base class unchanged.
        """
        network = initialize_model(config, d_out).to(config.device)
        super().__init__(
            config=config,
            model=network,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        # Kept on the instance so update() can look up flag_step_size.
        self.config = config

    def update(self, batch):
        """Run one training step on ``batch`` and return the results dict.

        Args:
            batch: a triple ``(x, y_true, metadata)``.  Assumes ``x``
                exposes an ``.x`` tensor whose first dimension sizes the
                perturbation -- TODO confirm the expected batch type.

        Returns:
            Whatever ``self.sanitize_dict`` returns for the results dict
            holding ``g``, ``y_true``, ``metadata``, ``y_pred`` (from the
            last forward pass) and ``objective``.
        """
        eps = self.config.flag_step_size
        n_passes = 3  # total number of forward/backward passes

        assert self.is_training

        # Unpack the batch and move its pieces onto the training device.
        inputs, y_true, metadata = batch
        inputs = move_to(inputs, self.device)
        y_true = move_to(y_true, self.device)
        groups = self.grouper.metadata_to_group(metadata)
        groups = move_to(groups, self.device)
        results = dict(g=groups, y_true=y_true, metadata=metadata)

        # Random starting perturbation, uniform in [-eps, eps], one row per
        # entry along the first dimension of inputs.x.
        n_rows = inputs.x.shape[0]
        delta = torch.FloatTensor(n_rows, self.model.emb_dim)
        delta = move_to(delta.uniform_(-eps, eps), self.device)
        delta.requires_grad_()

        self.optimizer.zero_grad()

        final_pass = n_passes - 1
        for pass_idx in range(n_passes):
            results['y_pred'] = self.model(inputs, delta)
            objective = self.objective(results)
            # Each pass contributes 1/n_passes of its loss, so the
            # accumulated parameter gradient is an average over passes.
            objective /= n_passes
            objective.backward()

            if pass_idx == final_pass:
                # No perturbation update is needed after the last pass.
                break

            # Ascent step: move the perturbation along the sign of its
            # gradient, then clear that gradient before the next pass.
            ascent = eps * torch.sign(delta.grad.detach())
            delta.data = (delta.detach() + ascent).data
            delta.grad.zero_()

        if self.max_grad_norm:
            clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        # Undo the 1/n_passes scaling so the last pass's loss is reported.
        results['objective'] = objective.item() * n_passes
        self.step_schedulers(
            is_epoch=False,
            metrics=results,
            log_access=False,
        )

        # Record the step in the log before handing the results back.
        self.update_log(results)
        return self.sanitize_dict(results)

    def objective(self, results):
        """Return the loss of ``results['y_pred']`` against ``y_true``.

        Delegates to ``self.loss.compute`` with ``return_dict=False``.
        """
        y_pred = results['y_pred']
        y_true = results['y_true']
        return self.loss.compute(y_pred, y_true, return_dict=False)