1 from algorithms.group_algorithm import GroupAlgorithm
2 from optimizer import initialize_optimizer
3 from scheduler import initialize_scheduler
4 from torch.nn.utils import clip_grad_norm_
5 from utils import move_to
6
7
class SingleModelAlgorithm(GroupAlgorithm):
    """
    An abstract class for algorithms that have one underlying model.

    Subclasses implement objective(); they may override _update() to change
    the optimization step itself.
    """

    def __init__(self, config, model, grouper, loss, metric, n_train_steps):
        """
        Args:
            - config: configuration object; reads device, max_grad_norm,
              scheduler_metric_name and no_group_logging, and is also
              passed to initialize_optimizer() / initialize_scheduler()
            - model: the single underlying model being trained
            - grouper: maps batch metadata to group indices
            - loss: loss object; always logged
            - metric: optional additional metric to log (may be None)
            - n_train_steps: forwarded to initialize_scheduler()
        """
        # get metrics: the loss is always logged, the metric only if given
        self.loss = loss
        self.metric = metric
        logged_metrics = [self.loss]
        if self.metric is not None:
            logged_metrics.append(self.metric)
        # initialize optimizer and scheduler
        self.optimizer = initialize_optimizer(config, model)
        self.max_grad_norm = config.max_grad_norm
        scheduler = initialize_scheduler(
            config, self.optimizer, n_train_steps)
        # initialize the module
        super().__init__(
            device=config.device,
            grouper=grouper,
            logged_metrics=logged_metrics,
            logged_fields=['objective'],
            schedulers=[scheduler],
            scheduler_metric_names=[config.scheduler_metric_name],
            no_group_logging=config.no_group_logging,
        )
        # NOTE(review): the model is attached only after super().__init__(),
        # presumably because the base class is a torch.nn.Module, which
        # refuses submodule assignment before Module.__init__() - confirm.
        self.model = model

    def process_batch(self, batch):
        """
        Helper for update() and evaluate() that processes a batch.

        Moves x and y_true to self.device, computes group indices from the
        metadata, and runs the model forward.

        Args:
            - batch (tuple of Tensors): a batch of data yielded by data
              loaders, unpacked as (x, y_true, metadata)
        Output:
            - results (dictionary): information about the batch
                - g (Tensor): group indices, moved to self.device
                - y_true (Tensor): labels, moved to self.device
                - y_pred (Tensor): model outputs
                - metadata (Tensor): metadata as given (not moved)
        """
        x, y_true, metadata = batch
        x = move_to(x, self.device)
        y_true = move_to(y_true, self.device)
        g = move_to(self.grouper.metadata_to_group(metadata), self.device)

        if self.model.needs_y:
            # Label-consuming models receive y_true only in training mode;
            # in evaluation mode they are given None instead.
            if self.training:
                outputs = self.model(x, y_true)
            else:
                outputs = self.model(x, None)
        else:
            outputs = self.model(x)

        return {
            'g': g,
            'y_true': y_true,
            'y_pred': outputs,
            'metadata': metadata,
        }

    def objective(self, results):
        """
        Compute the training objective from process_batch() results.

        Must be implemented by subclasses. The returned value must support
        .item() (used by evaluate() and _update()) and .backward() (used by
        _update()).
        """
        raise NotImplementedError

    def evaluate(self, batch):
        """
        Process the batch and update the log, without updating the model.

        Must be called while the algorithm is not in training mode.

        Args:
            - batch (tuple of Tensors): a batch of data yielded by data
              loaders
        Output:
            - results (dictionary): the process_batch() results plus the
              scalar objective, passed through sanitize_dict():
                - g
                - y_true
                - y_pred
                - metadata
                - objective (float)
        """
        assert not self.is_training
        results = self.process_batch(batch)
        results['objective'] = self.objective(results).item()
        self.update_log(results)
        return self.sanitize_dict(results)

    def update(self, batch):
        """
        Process the batch, update the log, and update the model.

        Must be called while the algorithm is in training mode.

        Args:
            - batch (tuple of Tensors): a batch of data yielded by data
              loaders
        Output:
            - results (dictionary): the process_batch() results plus the
              scalar objective, passed through sanitize_dict():
                - g
                - y_true
                - y_pred
                - metadata
                - objective (float)
        """
        assert self.is_training
        # process batch, then take an optimization step
        results = self.process_batch(batch)
        self._update(results)
        # log results
        self.update_log(results)
        return self.sanitize_dict(results)

    def _update(self, results):
        """
        Compute the objective and update the model.

        Also stores the scalar objective in the results dictionary yielded
        by process_batch(), under 'objective'. Should be overridden to
        change the algorithm update beyond modifying the objective.
        """
        # compute objective
        objective = self.objective(results)
        results['objective'] = objective.item()
        # backpropagate, optionally clip gradients, then step the optimizer
        self.model.zero_grad()
        objective.backward()
        if self.max_grad_norm:
            # a falsy max_grad_norm (e.g. None or 0) disables clipping
            clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
        # schedulers are stepped once per update() call, with is_epoch=False
        self.step_schedulers(
            is_epoch=False,
            metrics=results,
            log_access=False,
        )