algorithms/GCL.py source

from algorithms.single_model_algorithm import SingleModelAlgorithm
from models.initializer import initialize_model
import torch
from torch_geometric.data import Batch
from torch.nn.utils import clip_grad_norm_
from utils import move_to


class GCL(SingleModelAlgorithm):
    def __init__(self, config, d_out, grouper, loss,
                 metric, n_train_steps):
        featurizer, projector, classifier, combined_model = \
            initialize_model(config, d_out, is_featurizer=True,
                             include_projector=True)
        featurizer = featurizer.to(config.device)
        projector = projector.to(config.device)
        classifier = classifier.to(config.device)
        model = combined_model.to(config.device)
        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        # algorithm hyperparameters
        self.use_cl = config.use_cl
        self.aug_prob = config.gcl_aug_prob
        self.aug_type = config.aug_type
        # If use_cl=False and aug_ratio is 0.0, this should be
        # equivalent to ERM
        self.aug_ratio = config.gcl_aug_ratio
        # set model components (self.model is set by the base class)
        self.featurizer = featurizer
        self.projector = projector
        self.classifier = classifier

        # additional logging, copied from deepCORAL, but it throws an
        # error at the end of the epoch because the field is "missing":
        # self.logged_fields.append('gsimclr_loss')

    def drop_nodes(self, data):
        """
        Drop a random fraction of the graph's nodes, in place.

        Adapted from:
        https://github.com/Shen-Lab/GraphCL/blob/1d43f79d7f33f8133f9d4b4b8254d8aaeb09a615/semisupervised_TU/pre-training/tu_dataset.py#L139

        The original method operates on the adjacency matrix like so:

            idx_drop = idx_perm[:drop_num]
            idx_nondrop = idx_perm[drop_num:]
            ...
            edge_index = data.edge_index.numpy()
            adj = torch.zeros((node_num, node_num))
            adj[edge_index[0], edge_index[1]] = 1
            adj = adj[idx_nondrop, :][:, idx_nondrop]
            edge_index = adj.nonzero().t()

        That is compact and fairly efficient, but here we have to
        operate directly on the node index and edge lists, taking care
        to keep the edge attributes aligned.
        """
        aug_ratio = self.aug_ratio
        node_num = data.x.size()[0]
        _, edge_num = data.edge_index.size()

        # Directly model the uniform drop probability over nodes
        drop_num = int(node_num * aug_ratio)
        idx_perm = torch.randperm(node_num)
        # sorted for readability when debugging
        idx_nondrop = idx_perm[drop_num:].sort().values.cuda()

        # Realize this as a subselection of the edges in the graph,
        # noting that this is not an elegant process, because we need
        # to preserve the original edge attributes
        orig_edge_index = data.edge_index

        src_subselect = torch.nonzero(
            idx_nondrop[..., None] == orig_edge_index[0])[:, 1]
        dest_subselect = torch.nonzero(
            idx_nondrop[..., None] == orig_edge_index[1])[:, 1]
        edge_subselect = src_subselect[torch.nonzero(
            src_subselect[..., None] == dest_subselect)[:, 0]].sort().values

        new_edge_index = orig_edge_index[:, edge_subselect]
        _, new_edge_num = new_edge_index.size()

        if data.edge_attr is not None:
            # This would only hold for undirected graph datasets, where
            # the assumption is that we are removing undirected edge
            # pairs i <-> j:
            # assert (edge_num - new_edge_num) % 2 == 0
            data.edge_attr = data.edge_attr[edge_subselect]

        data.edge_index = new_edge_index
        # data.x is left as-is: we do not modify the node features

    def permute_edges(self, data):
        """
        Replace a random fraction of edges with random new edges, in
        place. Ported from the same repository as drop_nodes.

        Current behavior:
        In-place replacement of k edges (i -> j) with k edges
        (i' -> j'), where i' and j' are generated uniformly over the
        node indices, with i' != j'.

        TODO 1:
        Currently the edge_attr of (i -> j) is inherited by (i' -> j').
        For datasets that include edge_attr this is not very
        semantically sound, since the joint distribution over
        (i_feat, j_feat, e_attr), with i_feat, j_feat in
        {node_features} and e_attr in {edge_attr_types}, is not being
        respected.

        TODO 2:
        A non-naive, 'paired' version will also rely on the assumption
        that this is a default PyTorch Geometric data object, where the
        edge ordering is such that the two directions of a single edge
        occur in pairs, so that we can operate on bidirectional edges
        easily by stride-2 indexing, e.g.
            data.edge_index = tensor([[0, 1, 1, 2, 2, 3],
                                      [1, 0, 2, 1, 3, 2]])
        """
        aug_ratio = self.aug_ratio
        node_num = data.x.size()[0]
        _, edge_num = data.edge_index.size()

        paired_perms = False  # for later, TODO
        if paired_perms:
            # undirected graph, paired edges
            assert edge_num % 2 == 0

        permute_num = int(edge_num * aug_ratio)

        orig_edge_index = data.edge_index

        if permute_num > 0:
            edges_to_insert = torch.multinomial(
                torch.ones(permute_num, node_num), 2,
                replacement=False).t().cuda()
            insertion_indices = torch.multinomial(
                torch.ones(edge_num), permute_num, replacement=False)

            if paired_perms:
                raise NotImplementedError

            # modified in place
            orig_edge_index[:, insertion_indices] = edges_to_insert
        else:
            # augmentation silently does nothing
            pass

        # May not be in place for the paired/non-naive version:
        # try:
        #     data.edge_index = new_edge_index
        # except:
        #     pass

    def extract_subgraph(self, data):
        """
        TODO: It would be nice to have this third augmentation as well,
        since the three model distinct ways of conceptualizing types of
        graph 'semantemes'.

        The paper uses a random walk to generate a subgraph of nodes.
        """
        raise NotImplementedError

    def gsimclr_loss(self, z1, z2):
        # Temporarily placed here in the framework to minimize the
        # sprawl of changes. z1 and z2 are the projector embeddings for
        # a pair of graphs.
        # See Appendix A, Algorithm 1:
        # https://yyou1996.github.io/files/neurips2020_graphcl_supplement.pdf
        # Code from:
        # https://github.com/Shen-Lab/GraphCL/blob/d857849d51bb168568267e07007c0b0c8bb6d869/transferLearning_MoleculeNet_PPI/bio/pretrain_graphcl.py#L57
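        # In symbols: with cosine similarity
        # sim(a, b) = a.b / (|a| |b|) and temperature T, the loss for
        # the positive pair (z1_i, z2_i) in a batch of size N is
        #     l_i = -log( exp(sim(z1_i, z2_i) / T)
        #                 / sum_{j != i} exp(sim(z1_i, z2_j) / T) )
        # The einsum below builds the full N x N similarity matrix in
        # one shot; its diagonal holds the positive pairs.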
        T = 0.1
        batch_size, _ = z1.size()
        z1_abs = z1.norm(dim=1)
        z2_abs = z2.norm(dim=1)

        sim_matrix = (torch.einsum('ik,jk->ij', z1, z2)
                      / torch.einsum('i,j->ij', z1_abs, z2_abs))
        sim_matrix = torch.exp(sim_matrix / T)
        pos_sim = sim_matrix[range(batch_size), range(batch_size)]
        loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
        loss = -torch.log(loss).mean()
        return loss

    def update(self, batch):
        ######################################################################
        # NOTE How the augmentations are applied in the GCL paper
        # codebase:
        # 1) In one version of the pretraining reference code, the
        #    augmentations are applied in the get(idx) method of the
        #    PyG InMemoryDataset subclass being used. This then always
        #    returns the pair of data objects (data, data_aug). See:
        #    https://github.com/Shen-Lab/GraphCL/blob/e9e598d478d4a4bff94a3e95a078569c028f1d88/unsupervised_TU/aug.py#L203
        # 2) In a second version, two dataloaders are created and
        #    zipped, and since they are not shuffled, the original and
        #    augmented graphs stay aligned. See:
        #    https://github.com/Shen-Lab/GraphCL/blob/d857849d51bb168568267e07007c0b0c8bb6d869/transferLearning_MoleculeNet_PPI/chem/pretrain_graphcl.py#L80
        # Thus, with a two-argument loss fn gsimclr_loss(z1, z2) and a
        # nice formulation using einsum, subtracting out the diagonal
        # to separate positive pairs from negative pairs, you can
        # compute the contrastive loss.
        #
        # Currently we have the node-dropping and edge-permuting
        # augmentations implemented, with the intention to add the
        # subgraph one if time allows or it proves useful. See
        # permute_edges for commentary on how the original edge
        # permutation was implemented.
        #
        # For the shortest path to an implementation in this framework
        # as it stands, we start by applying the augmentation inline
        # here, after the dataloader has returned a batch of graphs.
        # The initial intuition is that this might be fairly
        # inefficient, since it cannot be parallelized by the
        # dataloader workers; it blocks during training instead. The
        # computations are all torch or native Python, though.
        ######################################################################
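        # Sketch of pattern (2) above (illustrative only; the loader
        # names are hypothetical):
        #
        #   loader     = DataLoader(dataset, shuffle=False, ...)
        #   loader_aug = DataLoader(dataset_aug, shuffle=False, ...)
        #   for data, data_aug in zip(loader, loader_aug):
        #       z1 = projector(featurizer(data))
        #       z2 = projector(featurizer(data_aug))
        #       loss = gsimclr_loss(z1, z2)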
        assert self.is_training

        # process batch
        x, y_true, metadata = batch
        x = move_to(x, self.device)
        y_true = move_to(y_true, self.device)
        g = move_to(self.grouper.metadata_to_group(metadata), self.device)
        results = {
            'g': g,
            'y_true': y_true,
            'metadata': metadata,
        }

        # Augmentation options
        # (currently singular; TODO make composable)
        augmentations = {
            'node_drop': self.drop_nodes,
            'edge_perm': self.permute_edges,
        }
        aug_list = list(augmentations.values())

        if self.aug_type == 'random':
            def rand_aug(data):
                n = torch.randint(2, (1,)).item()
                fn = aug_list[n]
                return fn(data)
            aug_fn = rand_aug
        elif self.aug_type in augmentations:
            aug_fn = augmentations[self.aug_type]
        else:
            raise NotImplementedError

        batch_size = x.num_graphs
        # torch.multinomial is slow, so we want to do this once per batch
        aug_mask = torch.multinomial(
            torch.tensor((1 - self.aug_prob, self.aug_prob)),
            batch_size, replacement=True)
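        # (torch.bernoulli(torch.full((batch_size,), self.aug_prob))
        # should yield an equivalent 0/1 mask, if this line ever shows
        # up in a profile.)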
        # If not "use_cl": unpack the batch of graphs and apply the
        # augmentation to each graph, in place.
        # TODO: if "use_cl", create (orig, aug) pairs, extract the
        # features, projections, and outputs
        # (h1, h2, z1, z2, y_pred1, y_pred2),
        # then modify the objective to call the new gsimclr loss
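        # A minimal sketch of that "use_cl" path (assumed wiring, not
        # active code):
        #
        #   graphs_aug = [g.clone() for g in x.to_data_list()]
        #   for g in graphs_aug:
        #       aug_fn(g)                       # in-place augmentation
        #   x_aug = Batch.from_data_list(graphs_aug)
        #   z1 = self.projector(self.featurizer(x))
        #   z2 = self.projector(self.featurizer(x_aug))
        #   objective = objective + self.gsimclr_loss(z1, z2)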
        graphs = x.to_data_list()

        for i in range(batch_size):
            if aug_mask[i]:
                aug_fn(graphs[i])
            else:
                pass  # original graph kept unchanged

        x = Batch.from_data_list(graphs)

        # Continue as with other methods/ERM
        outputs = self.model(x)
        results['y_pred'] = outputs

        objective = self.objective(results)
        results['objective'] = objective.item()

        # placeholder for contrastive-loss logging:
        # gsimclr_loss = 0.
        # if isinstance(gsimclr_loss, torch.Tensor):
        #     results['gsimclr_loss'] = gsimclr_loss.item()
        # else:
        #     results['gsimclr_loss'] = gsimclr_loss

        self.optimizer.zero_grad()
        self.model.zero_grad()
        objective.backward()
        if self.max_grad_norm:
            clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        self.step_schedulers(
            is_epoch=False,
            metrics=results,
            log_access=False)

        # log results
        self.update_log(results)
        sanitized = self.sanitize_dict(results)
        return sanitized

    def objective(self, results):
        return self.loss.compute(
            results['y_pred'], results['y_true'], return_dict=False)