⬅ algorithms/FLAG.py source

1 from algorithms.single_model_algorithm import SingleModelAlgorithm
2 from models.initializer import initialize_model
3 import torch
4 from torch.nn.utils import clip_grad_norm_
5 from utils import move_to
6  
7  
class FLAG(SingleModelAlgorithm):
    """Single-model algorithm trained with FLAG-style adversarial augmentation.

    Each call to ``update`` does the following:

    * Draws a random additive perturbation.
    * Runs ``m`` forward/backward rounds, doing sign-gradient ascent on the
      perturbation between rounds.
    * Lets the model-parameter gradients of every round accumulate.
    * Takes a single optimizer step.

    Hyper-parameters read from ``config``:

    * ``flag_step_size``: ascent step size, also the half-width of the
      uniform initialisation.
    * ``flag_m`` (optional): number of rounds, defaulting to
      ``DEFAULT_FLAG_M``.
    """

    # Number of ascent rounds used when the config has no ``flag_m``;
    # 3 is the value this class previously hard-coded.
    DEFAULT_FLAG_M = 3

    def __init__(self, config, d_out, grouper, loss,
                 metric, n_train_steps):
        """Build the model on ``config.device`` and set up the base class.

        Args:
            config: experiment configuration. ``device`` is read here;
                ``flag_step_size`` and the optional ``flag_m`` are read in
                ``update``.
            d_out: output dimension passed to ``initialize_model``.
            grouper, loss, metric, n_train_steps: forwarded unchanged to
                ``SingleModelAlgorithm.__init__``.
        """
        model = initialize_model(config, d_out).to(config.device)
        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        # Kept so that ``update`` can read the FLAG hyper-parameters.
        self.config = config

    def update(self, batch):
        """Run one FLAG training step on ``batch`` and return the results.

        Args:
            batch: ``(x, y_true, metadata)`` tuple. ``x`` must expose
                ``x.x``, whose first dimension sizes the perturbation.
                Presumably this is the node-feature matrix of a batched
                graph (TODO confirm against the data loader).

        Returns:
            The sanitized results dict. It holds ``g``, ``y_true``,
            ``metadata``, ``y_pred`` from the final round, and
            ``objective``.

        Raises:
            ValueError: if ``config.flag_m`` is smaller than 1.
        """
        step_size = self.config.flag_step_size
        m = getattr(self.config, 'flag_m', self.DEFAULT_FLAG_M)
        if m < 1:
            raise ValueError(
                f'FLAG needs at least one ascent round, got flag_m={m}')

        assert self.is_training
        # process batch
        x, y_true, metadata = batch
        x = move_to(x, self.device)
        y_true = move_to(y_true, self.device)
        g = move_to(self.grouper.metadata_to_group(metadata), self.device)
        results = {
            'g': g,
            'y_true': y_true,
            'metadata': metadata,
        }

        # Random start in [-step_size, step_size]. It is created directly
        # on the target device rather than on the CPU and then copied over.
        # NOTE(review): the model is expected to consume this as a second
        # forward argument of width ``emb_dim``; confirm in model.forward.
        perturb = torch.empty(
            x.x.shape[0], self.model.emb_dim, device=self.device,
        ).uniform_(-step_size, step_size)
        perturb.requires_grad_()

        self.optimizer.zero_grad()
        for step in range(m):
            results['y_pred'] = self.model(x, perturb)
            # Scale by 1/m so the parameter gradients accumulated over the
            # m backward passes are averaged rather than summed.
            objective = self.objective(results) / m
            objective.backward()
            if step < m - 1:
                # Sign-gradient ascent on the perturbation. The model
                # parameter gradients are deliberately NOT cleared, so
                # they keep accumulating across rounds.
                with torch.no_grad():
                    perturb += step_size * torch.sign(perturb.grad)
                perturb.grad.zero_()

        if self.max_grad_norm:
            clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        # Undo the 1/m scaling so the logged value is the final round's loss.
        results['objective'] = objective.item() * m
        self.step_schedulers(
            is_epoch=False,
            metrics=results,
            log_access=False)

        # log results
        self.update_log(results)
        return self.sanitize_dict(results)

    def objective(self, results):
        """Return the loss of ``results['y_pred']`` w.r.t. ``y_true``.

        The result must be a single-element tensor, because ``update``
        calls ``.item()`` on it.
        """
        return self.loss.compute(
            results['y_pred'], results['y_true'], return_dict=False)