⬅ algorithms/algorithm.py source

1 import torch.nn as nn
2 from utils import move_to, detach_and_clone
3  
4  
class Algorithm(nn.Module):
    """
    Base class for training/evaluation algorithms.

    Subclasses implement update(), evaluate(), update_log(), get_log(),
    get_pretty_log_str() and step_schedulers(). This base class provides the
    train/eval switch, the internal-log bookkeeping (log_dict / has_log) and
    the sanitize_dict() helper.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device
        # Device that sanitize_dict() moves its output to by default.
        self.out_device = 'cpu'
        # Mirror nn.Module.training so is_training can be read before the
        # first call to train()/eval().
        self.is_training = self.training
        self._has_log = False
        self.reset_log()

    def update(self, batch):
        """
        Process the batch, update the log, and update the model
        Args:
            - batch (tuple of Tensors): a batch of data yielded by data loaders
        Output:
            - results (dictionary): information about the batch, such as:
                - g (Tensor)
                - y_true (Tensor)
                - metadata (Tensor)
                - loss (Tensor)
                - metrics (Tensor)
        """
        raise NotImplementedError

    def evaluate(self, batch):
        """
        Process the batch and update the log, without updating the model
        Args:
            - batch (tuple of Tensors): a batch of data yielded by data loaders
        Output:
            - results (dictionary): information about the batch, such as:
                - g (Tensor)
                - y_true (Tensor)
                - metadata (Tensor)
                - loss (Tensor)
                - metrics (Tensor)
        """
        raise NotImplementedError

    def train(self, mode=True):
        """
        Switch to train mode (or eval mode when mode is False) and reset the
        internal log. nn.Module.eval() routes through this method.
        Args:
            - mode (bool): True for train mode, False for eval mode
        Output:
            - self, matching the nn.Module.train() contract so calls such as
              `algorithm.eval()` can be chained
        """
        # Delegate to nn.Module first: it propagates `mode` to submodules
        # (and, in recent PyTorch versions, rejects non-bool values), so
        # is_training is only updated once the switch has gone through.
        super().train(mode)
        self.is_training = mode
        self.reset_log()
        return self

    @property
    def has_log(self):
        """
        Whether the internal log holds results. reset_log() clears it;
        presumably subclass implementations of update_log() set it.
        """
        return self._has_log

    def reset_log(self):
        """
        Resets log by clearing out the internal log, Algorithm.log_dict
        """
        self._has_log = False
        self.log_dict = {}

    def update_log(self, results):
        """
        Updates the internal log, Algorithm.log_dict
        Args:
            - results (dictionary)
        """
        raise NotImplementedError

    def get_log(self):
        """
        Sanitizes the internal log (Algorithm.log_dict) and outputs it.
        """
        raise NotImplementedError

    def get_pretty_log_str(self):
        raise NotImplementedError

    def step_schedulers(self, is_epoch, metrics=None, log_access=False):
        """
        Update all relevant schedulers
        Args:
            - is_epoch (bool): epoch-wise update if set to True, batch-wise
              update otherwise
            - metrics (dict or None): a dictionary of metrics that can be
              used for scheduler updates; None means no metrics (an empty
              dict). Defaults to None rather than {} so that one dict object
              is not shared across calls.
            - log_access (bool): whether metrics from self.get_log() can be
              used to update schedulers
        """
        raise NotImplementedError

    def sanitize_dict(self, in_dict, to_out_device=True):
        """
        Helper function that sanitizes dictionaries by:
            - moving to the specified output device
            - removing any gradient information
            - detaching and cloning the tensors
        Args:
            - in_dict (dictionary)
            - to_out_device (bool): if True (default), move the result to
              self.out_device
        Output:
            - out_dict (dictionary): sanitized version of in_dict
        """
        out_dict = detach_and_clone(in_dict)
        if to_out_device:
            out_dict = move_to(out_dict, self.out_device)
        return out_dict