⬅ algorithms/deepCORAL.py source

1 import torch
2 from algorithms.single_model_algorithm import SingleModelAlgorithm
3 from models.initializer import initialize_model
  • F401 'torch_geometric.nn.global_mean_pool' imported but unused
4 from torch_geometric.nn import global_mean_pool
5  
6 from gds.common.utils import split_into_groups
7  
  • E302 Expected 2 blank lines, found 1
class DeepCORAL(SingleModelAlgorithm):
    """
    Deep CORAL.

    This algorithm was originally proposed as an unsupervised domain
    adaptation algorithm. Here it adds, to the ordinary loss, a penalty on
    the difference between the feature means and feature covariances of
    every pair of groups that appear in a training batch.

    Original paper:
        @inproceedings{sun2016deep,
          title={Deep CORAL: Correlation alignment for deep domain adaptation},
          author={Sun, Baochen and Saenko, Kate},
          booktitle={European Conference on Computer Vision},
          pages={443--450},
          year={2016},
          organization={Springer}
        }

    The CORAL penalty function below is adapted from DomainBed's
    implementation:
    https://github.com/facebookresearch/DomainBed/blob/1a61f7ff44b02776619803a1dd12f952528ca531/domainbed/algorithms.py#L539
    """

    def __init__(self, config, d_out, grouper, loss, metric, n_train_steps):
        """
        Build the featurizer/classifier pair and register CORAL settings.

        Reads ``train_loader``, ``uniform_over_groups``, ``distinct_groups``,
        ``device`` and ``coral_penalty_weight`` from ``config``. The remaining
        arguments are forwarded unchanged to the parent constructor.

        Raises:
            AssertionError: if the config does not use the 'group' train
                loader with ``uniform_over_groups`` and ``distinct_groups``
                both truthy.
        """
        # check config
        assert config.train_loader == 'group', \
            "DeepCORAL requires config.train_loader == 'group'"
        assert config.uniform_over_groups, \
            "DeepCORAL requires config.uniform_over_groups"
        assert config.distinct_groups, \
            "DeepCORAL requires config.distinct_groups"
        # initialize models
        featurizer, classifier = initialize_model(
            config, d_out=d_out, is_featurizer=True)
        featurizer = featurizer.to(config.device)
        classifier = classifier.to(config.device)
        model = torch.nn.Sequential(featurizer, classifier).to(config.device)
        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        # algorithm hyperparameters
        self.penalty_weight = config.coral_penalty_weight
        # additional logging
        self.logged_fields.append('penalty')
        # set model components; kept separately so process_batch can expose
        # the intermediate features in addition to the final outputs
        self.featurizer = featurizer
        self.classifier = classifier

    def coral_penalty(self, x, y):
        """
        Return the CORAL penalty between two sets of features.

        The penalty is the mean squared difference of the per-feature means
        plus the mean squared difference of the feature covariance matrices.

        Args:
            x: feature Tensor whose last dimension is the feature dimension.
            y: feature Tensor with the same feature dimension as ``x``.

        Returns:
            A 0-dim Tensor holding ``mean_diff + cova_diff``.
        """
        if x.dim() > 2:
            # featurizers output Tensors of size
            # (batch_size, ..., feature dimensionality).
            # we flatten to Tensors of size (*, feature dimensionality)
            x = x.view(-1, x.size(-1))
            y = y.view(-1, y.size(-1))

        mean_x = x.mean(0, keepdim=True)
        mean_y = y.mean(0, keepdim=True)
        cent_x = x - mean_x
        cent_y = y - mean_y
        # Unbiased (n - 1) normalisation. The denominator is clamped to 1 so
        # that a group with a single sample yields a zero covariance matrix
        # instead of 0 / 0 = NaN, which would poison the whole objective.
        cova_x = (cent_x.t() @ cent_x) / max(len(x) - 1, 1)
        cova_y = (cent_y.t() @ cent_y) / max(len(y) - 1, 1)

        mean_diff = (mean_x - mean_y).pow(2).mean()
        cova_diff = (cova_x - cova_y).pow(2).mean()

        return mean_diff + cova_diff

    def process_batch(self, batch):
        """
        Override.

        Run the forward pass on one batch and package the results.

        Args:
            batch: a triple ``(x, y_true, metadata)``.

        Returns:
            dict with keys 'g' (group ids computed from the metadata),
            'y_true', 'y_pred' (classifier outputs), 'metadata' and
            'features' (featurizer outputs, consumed by ``objective``).
        """
        # forward pass
        x, y_true, metadata = batch
        x = x.to(self.device)
        y_true = y_true.to(self.device)
        g = self.grouper.metadata_to_group(metadata).to(self.device)
        features = self.featurizer(x)
        outputs = self.classifier(features)

        # package the results
        results = {
            'g': g,
            'y_true': y_true,
            'y_pred': outputs,
            'metadata': metadata,
            'features': features,
        }
        return results

    def objective(self, results):
        """
        Return the training objective: average loss plus weighted penalty.

        Side effects on ``results``: the 'features' entry is removed and a
        'penalty' entry (a plain Python number) is added for logging.

        During training the penalty is the mean of ``coral_penalty`` over all
        unordered pairs of groups present in the batch; otherwise it is 0.
        """
        # extract features (removed from results so they are not carried on)
        features = results.pop('features')

        if self.is_training:
            # split into groups
            unique_groups, group_indices, _ = split_into_groups(results['g'])
            n_groups_per_batch = unique_groups.numel()
            # gather each group's features once instead of re-indexing them
            # for every pair inside the O(n_groups^2) loop below
            group_features = [
                features[group_indices[i_group]]
                for i_group in range(n_groups_per_batch)
            ]
            # compute penalty
            penalty = torch.zeros(1, device=self.device)
            for i_group in range(n_groups_per_batch):
                for j_group in range(i_group + 1, n_groups_per_batch):
                    penalty += self.coral_penalty(
                        group_features[i_group], group_features[j_group])
            if n_groups_per_batch > 1:
                # get the mean penalty over the n * (n - 1) / 2 pairs
                penalty /= (n_groups_per_batch * (n_groups_per_batch - 1) / 2)
        else:
            penalty = 0.

        # save penalty as a plain number so it can be logged
        if isinstance(penalty, torch.Tensor):
            results['penalty'] = penalty.item()
        else:
            results['penalty'] = penalty

        avg_loss = self.loss.compute(
            results['y_pred'], results['y_true'], return_dict=False)

        return avg_loss + penalty * self.penalty_weight