experiments.algorithms.FLAG

experiments/algorithms/FLAG.py
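
FLAG (Free Large-scale Adversarial Augmentation on Graphs) trains the model on
adversarially perturbed node embeddings: each update initializes a random
perturbation, refines it over m sign-gradient ascent steps
(perturb <- perturb + step_size * sign(grad)), and reuses the gradient from each
step "for free" to accumulate an averaged parameter gradient.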
from algorithms.single_model_algorithm import SingleModelAlgorithm
from models.initializer import initialize_model
import torch
from torch.nn.utils import clip_grad_norm_
from utils import move_to


class FLAG(SingleModelAlgorithm):
    def __init__(self, config, d_out, grouper, loss,
                 metric, n_train_steps):
        model = initialize_model(config, d_out).to(config.device)
        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        self.config = config

    def update(self, batch):
        # Inner-maximization hyperparameters: step_size controls the ascent
        # step on the perturbation; m is the number of ascent steps
        # (hard-coded to 3 here).
        step_size = self.config.flag_step_size
        m = 3

        assert self.is_training

        # Unpack the batch and move it to the device.
        x, y_true, metadata = batch
        x = move_to(x, self.device)
        y_true = move_to(y_true, self.device)
        g = move_to(self.grouper.metadata_to_group(metadata), self.device)
        results = {
            'g': g,
            'y_true': y_true,
            'metadata': metadata,
        }

        # Initialize a per-node perturbation uniformly in [-step_size, step_size];
        # it is refined by sign-gradient ascent below.
        perturb_shape = (x.x.shape[0], self.model.emb_dim)
        perturb = torch.empty(perturb_shape, device=self.device).uniform_(-step_size, step_size)
        perturb.requires_grad_()

        self.optimizer.zero_grad()

        # First forward pass with the initial random perturbation. The loss is
        # scaled by 1/m so that gradients accumulated across the m passes
        # average rather than sum.
        outputs = self.model(x, perturb)
        results['y_pred'] = outputs
        objective = self.objective(results)
        objective /= m

        # "Free" adversarial training: each backward pass both accumulates
        # parameter gradients and yields the gradient used for a sign-ascent
        # step on the perturbation.
        for _ in range(m - 1):
            objective.backward()
            perturb_data = perturb.detach() + step_size * torch.sign(perturb.grad.detach())
            perturb.data = perturb_data
            perturb.grad.zero_()

            # Re-evaluate the loss at the updated perturbation.
            outputs = self.model(x, perturb)
            results['y_pred'] = outputs
            objective = self.objective(results)
            objective /= m

        # Final backward pass; parameter gradients now hold the average over
        # all m perturbed forward passes.
        objective.backward()
        if self.max_grad_norm:
            clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        # Undo the 1/m scaling so the logged objective is the unscaled loss
        # from the final ascent step.
        results['objective'] = objective.item() * m
        self.step_schedulers(
            is_epoch=False,
            metrics=results,
            log_access=False)

        # log results
        self.update_log(results)
        return self.sanitize_dict(results)

    def objective(self, results):
        return self.loss.compute(results['y_pred'], results['y_true'], return_dict=False)
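
For reference, update() assumes the model exposes an emb_dim attribute and a
forward(x, perturb=None) that adds the perturbation to the initial node
embeddings, and that x is a PyG-style batch with node features at x.x. A
minimal sketch of that interface (PerturbableGNN and all names below are
hypothetical, and message passing/pooling is elided):

import torch
import torch.nn as nn

class PerturbableGNN(nn.Module):
    def __init__(self, in_dim, emb_dim, d_out):
        super().__init__()
        self.emb_dim = emb_dim  # read by FLAG.update to size the perturbation
        self.encoder = nn.Linear(in_dim, emb_dim)
        self.head = nn.Linear(emb_dim, d_out)

    def forward(self, x, perturb=None):
        h = self.encoder(x.x)  # initial node embeddings, (num_nodes, emb_dim)
        if perturb is not None:
            h = h + perturb    # adversarial feature augmentation
        # (graph convolutions and readout would go here)
        return self.head(h)

# Smoke test with a stand-in batch of 5 nodes and 16-dim features:
from types import SimpleNamespace
batch = SimpleNamespace(x=torch.randn(5, 16))
model = PerturbableGNN(in_dim=16, emb_dim=32, d_out=4)
out = model(batch, perturb=torch.zeros(5, 32))  # shape: (5, 4)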