import torch

from algorithms.single_model_algorithm import SingleModelAlgorithm
from models.initializer import initialize_model
from gds.common.utils import split_into_groups


class DeepCORAL(SingleModelAlgorithm):
"""
Deep CORAL.
This algorithm was originally proposed as an unsupervised domain adaptation algorithm.
Original paper:
@inproceedings{sun2016deep,
title={Deep CORAL: Correlation alignment for deep domain adaptation},
author={Sun, Baochen and Saenko, Kate},
booktitle={European Conference on Computer Vision},
pages={443--450},
year={2016},
organization={Springer}
}
The CORAL penalty function below is adapted from DomainBed's implementation:
https://github.com/facebookresearch/DomainBed/blob/1a61f7ff44b02776619803a1dd12f952528ca531/domainbed/algorithms.py#L539
"""
    def __init__(self, config, d_out, grouper, loss, metric, n_train_steps):
        # check config: CORAL needs each batch to contain multiple distinct
        # groups to align, so require the group-balanced train loader
        assert config.train_loader == 'group'
        assert config.uniform_over_groups
        assert config.distinct_groups
        # initialize models
        featurizer, classifier = initialize_model(config, d_out=d_out, is_featurizer=True)
        featurizer = featurizer.to(config.device)
        classifier = classifier.to(config.device)
        model = torch.nn.Sequential(featurizer, classifier).to(config.device)
        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        # algorithm hyperparameters
        self.penalty_weight = config.coral_penalty_weight
        # additional logging
        self.logged_fields.append('penalty')
        # set model components
        self.featurizer = featurizer
        self.classifier = classifier

    def coral_penalty(self, x, y):
        if x.dim() > 2:
            # Featurizers output Tensors of size (batch_size, ..., feature dimensionality);
            # flatten them to Tensors of size (*, feature dimensionality).
            x = x.view(-1, x.size(-1))
            y = y.view(-1, y.size(-1))

        mean_x = x.mean(0, keepdim=True)
        mean_y = y.mean(0, keepdim=True)
        cent_x = x - mean_x
        cent_y = y - mean_y
        # unbiased estimates of the feature covariance matrices
        cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
        cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)

        mean_diff = (mean_x - mean_y).pow(2).mean()
        cova_diff = (cova_x - cova_y).pow(2).mean()

        return mean_diff + cova_diff

    def process_batch(self, batch):
        """
        Overrides SingleModelAlgorithm.process_batch() to additionally return
        the featurizer outputs, which objective() needs for the CORAL penalty.
        """
        # forward pass
        x, y_true, metadata = batch
        x = x.to(self.device)
        y_true = y_true.to(self.device)
        g = self.grouper.metadata_to_group(metadata).to(self.device)
        features = self.featurizer(x)
        outputs = self.classifier(features)

        # package the results
        results = {
            'g': g,
            'y_true': y_true,
            'y_pred': outputs,
            'metadata': metadata,
            'features': features,
        }
        return results

    def objective(self, results):
        # extract features
        features = results.pop('features')

        if self.is_training:
            # split into groups
            unique_groups, group_indices, _ = split_into_groups(results['g'])

            # compute the CORAL penalty over all pairs of groups in the batch
            n_groups_per_batch = unique_groups.numel()
            penalty = torch.zeros(1, device=self.device)
            for i_group in range(n_groups_per_batch):
                for j_group in range(i_group + 1, n_groups_per_batch):
                    penalty += self.coral_penalty(features[group_indices[i_group]], features[group_indices[j_group]])
            if n_groups_per_batch > 1:
                penalty /= (n_groups_per_batch * (n_groups_per_batch - 1) / 2)  # average over group pairs
        else:
            penalty = 0.

        # save penalty for logging
        if isinstance(penalty, torch.Tensor):
            results['penalty'] = penalty.item()
        else:
            results['penalty'] = penalty

        avg_loss = self.loss.compute(results['y_pred'], results['y_true'], return_dict=False)
        return avg_loss + penalty * self.penalty_weight
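

if __name__ == '__main__':
    # A minimal sanity-check sketch for coral_penalty, runnable without a
    # WILDS config (an assumption on our part: coral_penalty never touches
    # self, so it can be called unbound with None in place of an instance).
    torch.manual_seed(0)
    x = torch.randn(32, 16)
    y = torch.randn(32, 16)
    # The penalty of a feature matrix against itself is exactly zero...
    assert DeepCORAL.coral_penalty(None, x, x).item() == 0.0
    # ...and positive for two independently drawn groups.
    assert DeepCORAL.coral_penalty(None, x, y).item() > 0.0
    print('coral_penalty sanity checks passed')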