experiments.algorithms.algorithm

experiments/algorithms/algorithm.py
import torch.nn as nn
from utils import move_to, detach_and_clone


class Algorithm(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = device
        self.out_device = 'cpu'
        self._has_log = False
        self.reset_log()

    def update(self, batch):
        """
        Process the batch, update the log, and update the model
        Args:
            - batch (tuple of Tensors): a batch of data yielded by data loaders
        Output:
            - results (dictionary): information about the batch, such as:
                - g (Tensor)
                - y_true (Tensor)
                - metadata (Tensor)
                - loss (Tensor)
                - metrics (Tensor)
        """
        raise NotImplementedError

    def evaluate(self, batch):
        """
        Process the batch and update the log, without updating the model
        Args:
            - batch (tuple of Tensors): a batch of data yielded by data loaders
        Output:
            - results (dictionary): information about the batch, such as:
                - g (Tensor)
                - y_true (Tensor)
                - metadata (Tensor)
                - loss (Tensor)
                - metrics (Tensor)
        """
        raise NotImplementedError

    def train(self, mode=True):
        """
        Switch to train mode (mode=True) or eval mode (mode=False), and reset the internal log
        """
        self.is_training = mode
        super().train(mode)
        self.reset_log()

    @property
    def has_log(self):
        return self._has_log

    def reset_log(self):
        """
        Resets log by clearing out the internal log, Algorithm.log_dict
        """
        self._has_log = False
        self.log_dict = {}

    def update_log(self, results):
        """
        Updates the internal log, Algorithm.log_dict
        Args:
            - results (dictionary)
        """
        raise NotImplementedError

    def get_log(self):
        """
        Sanitizes the internal log (Algorithm.log_dict) and outputs it.
        """
        raise NotImplementedError

    def get_pretty_log_str(self):
        """
        Returns a human-readable string representation of the sanitized log.
        """
        raise NotImplementedError

    def step_schedulers(self, is_epoch, metrics={}, log_access=False):
        """
        Update all relevant schedulers
        Args:
            - is_epoch (bool): epoch-wise update if set to True, batch-wise update otherwise
            - metrics (dict): a dictionary of metrics that can be used for scheduler updates
            - log_access (bool): whether metrics from self.get_log() can be used to update schedulers
        """
        raise NotImplementedError

    def sanitize_dict(self, in_dict, to_out_device=True):
        """
        Helper function that sanitizes dictionaries by:
            - moving to the specified output device
            - removing any gradient information
            - detaching and cloning the tensors
        Args:
            - in_dict (dictionary)
        Output:
            - out_dict (dictionary): sanitized version of in_dict
        """
        out_dict = detach_and_clone(in_dict)
        if to_out_device:
            out_dict = move_to(out_dict, self.out_device)
        return out_dict
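
The class above only defines the interface: update(), evaluate(), update_log(), get_log(), get_pretty_log_str(), and step_schedulers() all raise NotImplementedError and are meant to be overridden. Below is a minimal, hypothetical sketch of what a concrete subclass could look like; the name ExampleERM, the process_batch() helper, and the constructor arguments are illustrative assumptions, not part of the source file. The sketch assumes the utils module providing move_to and detach_and_clone (imported at the top of the file) is importable, and that this file itself can be imported as algorithm.

# Hypothetical sketch -- not part of experiments/algorithms/algorithm.py.
import torch
import torch.nn as nn

from algorithm import Algorithm  # assumes this file is importable as `algorithm`


class ExampleERM(Algorithm):
    """Illustrative empirical-risk-minimization subclass of Algorithm."""

    def __init__(self, model, optimizer, loss_fn, device, schedulers=None):
        super().__init__(device)
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.schedulers = schedulers or []

    def process_batch(self, batch):
        # Unpack the batch, move inputs to the compute device, run the model.
        x, y_true, metadata = batch
        x = x.to(self.device)
        y_true = y_true.to(self.device)
        y_pred = self.model(x)
        return {
            'y_pred': y_pred,
            'y_true': y_true,
            'metadata': metadata,
            'loss': self.loss_fn(y_pred, y_true),
        }

    def update(self, batch):
        # Process the batch, take a gradient step, and log the results.
        assert self.is_training
        results = self.process_batch(batch)
        self.optimizer.zero_grad()
        results['loss'].backward()
        self.optimizer.step()
        self.update_log(results)
        return self.sanitize_dict(results)

    def evaluate(self, batch):
        # Process the batch and log the results without updating the model.
        assert not self.is_training
        with torch.no_grad():
            results = self.process_batch(batch)
        self.update_log(results)
        return self.sanitize_dict(results)

    def update_log(self, results):
        # Keep a running average of the loss across batches.
        self._has_log = True
        count = self.log_dict.get('count', 0)
        prev = self.log_dict.get('loss', 0.0)
        self.log_dict['loss'] = (prev * count + results['loss'].item()) / (count + 1)
        self.log_dict['count'] = count + 1

    def get_log(self):
        return {k: v for k, v in self.log_dict.items() if k != 'count'}

    def get_pretty_log_str(self):
        return ', '.join(f'{k}: {v:.4f}' for k, v in self.get_log().items())

    def step_schedulers(self, is_epoch, metrics=None, log_access=False):
        # Step epoch-wise schedulers only; this sketch has no batch-wise ones.
        if is_epoch:
            for scheduler in self.schedulers:
                scheduler.step()


# Illustrative usage with random data.
if __name__ == '__main__':
    model = nn.Linear(10, 2)
    algo = ExampleERM(
        model=model,
        optimizer=torch.optim.SGD(model.parameters(), lr=1e-2),
        loss_fn=nn.CrossEntropyLoss(),
        device='cpu',
    )
    batch = (torch.randn(8, 10), torch.randint(0, 2, (8,)), torch.zeros(8, 1))

    algo.train()                      # switch to train mode and reset the log
    algo.update(batch)                # forward + backward + optimizer step
    algo.step_schedulers(is_epoch=True)
    algo.eval()                       # calls train(False), resetting the log
    algo.evaluate(batch)              # forward pass only, no parameter update
    print(algo.get_pretty_log_str())

Note that algo.eval() goes through the overridden train(False), so the log is reset whenever the mode changes; the trainer is expected to read the log (via get_log or get_pretty_log_str) before switching modes.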