⬅ models/gnn.py source

1 import torch
2 import torch.nn.functional as F
  • E501 Line too long (92 > 79 characters)
3 from torch_geometric.nn import global_mean_pool, global_add_pool, GCNConv, GINConv, ChebConv
4 from ogb.graphproppred.mol_encoder import AtomEncoder
5  
6 from .conv import GCNConvNew, GINConvNew, ChebConvNew
7  
# Value handed to every ChebConv / ChebConvNew layer built in this module
# (ChebConv receives it as its third positional argument, ChebConvNew as
# its second).
Cheb_K = 3
9  
10  
11 # mol
class GNN(torch.nn.Module):
    """Graph-level wrapper around a node-embedding GNN backbone.

    A ``GNN_node`` or ``GNN_node_Virtualnode`` backbone produces node
    embeddings. They are optionally mean-pooled per graph and optionally
    projected to ``num_tasks`` outputs by a linear layer.

    Input:
        - batched Pytorch Geometric graph object
    Output (depends on the constructor arguments):
        - ``num_tasks`` given: ``graph_pred_linear`` applied to the pooled
          node embeddings
        - ``num_tasks=None, is_pooled=True``: the pooled node embeddings
        - ``num_tasks=None, is_pooled=False``: the tuple
          ``(node_embeddings, batched_data.batch)``
    """

    def __init__(self, gnn_type, dataset_group, num_tasks=128, num_layers=5,
                 emb_dim=300, dropout=0.5, is_pooled=True, **model_kwargs):
        """
        Args:
            - gnn_type (str): underscore-separated spec. The token before
              the first underscore is the layer type forwarded to the
              backbone. A spec ending in ``'virtual'`` selects the
              virtual-node backbone. A spec ending in ``'layers'`` reads
              the layer count from its second token (overriding
              ``num_layers``) and enables residual connections.
            - dataset_group (str): forwarded to the backbone.
            - num_tasks (int or None): size of the linear head; ``None``
              disables the head.
            - num_layers (int): number of message passing layers of GNN
            - emb_dim (int): dimensionality of hidden channels
            - dropout (float): dropout ratio applied to hidden channels
            - is_pooled (bool): mean-pool node embeddings per graph.
            - **model_kwargs: accepted and ignored.

        Raises:
            ValueError: if fewer than 2 layers are requested, or if a
                prediction head is requested without pooling.
        """
        # Initialise the Module machinery before touching any attribute.
        super().__init__()

        self.gnn_type = gnn_type
        self.dataset_group = dataset_group

        spec_tokens = gnn_type.split('_')
        residual = gnn_type.endswith('layers')
        if residual:
            num_layers = int(spec_tokens[1])

        self.num_layers = num_layers
        self.dropout = dropout
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.is_pooled = is_pooled

        # Width of whatever forward() returns last: the raw embedding when
        # there is no head, otherwise the number of tasks.
        self.d_out = self.emb_dim if num_tasks is None else self.num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        # Validate before building the backbone so a bad configuration
        # fails fast. This used to be an ``assert``, which disappears
        # under ``python -O``.
        if num_tasks is not None and not is_pooled:
            raise ValueError(
                "A prediction head (num_tasks is not None) requires "
                "is_pooled=True.")

        if gnn_type.endswith('virtual'):
            backbone_cls = GNN_node_Virtualnode
        else:
            backbone_cls = GNN_node
        self.gnn_node = backbone_cls(
            num_layers, emb_dim,
            dataset_group=self.dataset_group,
            gnn_type=spec_tokens[0],
            dropout=dropout,
            residual=residual)

        # Pooling function to generate whole-graph embeddings
        self.pool = global_mean_pool if self.is_pooled else None

        if num_tasks is None:
            self.graph_pred_linear = None
        else:
            self.graph_pred_linear = torch.nn.Linear(
                self.emb_dim, self.num_tasks)

    def forward(self, batched_data, perturb=None):
        """Run the backbone, then pool and apply the head if configured.

        ``perturb`` is passed through to the backbone unchanged.
        """
        h_node = self.gnn_node(batched_data, perturb)

        if self.graph_pred_linear is None and self.pool is None:
            return h_node, batched_data.batch

        h_graph = self.pool(h_node, batched_data.batch)
        if self.graph_pred_linear is None:
            return h_graph
        return self.graph_pred_linear(h_graph)
85  
86  
  • E266 Too many leading '#' for block comment
87 ### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """Stack of message-passing layers that produces node embeddings.

    Output:
        node representations
    """

    # Dataset groups that are built with the stock torch_geometric layers;
    # every other group uses the project-local *New layer variants.
    _STOCK_CONV_GROUPS = ('RotatedMNIST', 'ColoredMNIST', 'SBM', 'UPFD')

    def __init__(self, num_layer, emb_dim, dataset_group='mol',
                 gnn_type='gin', dropout=0.5, JK="last",
                 residual=False, batchnorm=True):
        '''
        num_layer (int): number of GNN message passing layers
        emb_dim (int): node embedding dimensionality
        dataset_group (str): selects the input node encoder and which
            convolution implementation is used
        gnn_type (str): 'gin', 'gcn' or 'cheb'
        dropout (float): dropout ratio applied after every layer
        JK (str): how per-layer outputs are combined in forward():
            "last" returns the final layer, "sum" adds all of them
        residual (bool): add each layer's input to its output
        batchnorm (bool): apply BatchNorm1d after every convolution

        Raises ValueError for num_layer < 2 or an unknown gnn_type, and
        NotImplementedError for an unknown dataset_group.
        '''
        super(GNN_node, self).__init__()

        self.dataset_group = dataset_group
        self.gnn_type = gnn_type

        self.num_layer = num_layer
        self.drop_ratio = dropout
        self.JK = JK
        # add residual connection or not
        self.residual = residual
        self.batchnorm = batchnorm

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.node_encoder = self._make_node_encoder(emb_dim)

        # List of GNNs, each paired with a batch norm over node embeddings.
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(num_layer):
            self.convs.append(self._make_conv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def _make_node_encoder(self, emb_dim):
        """Return the input encoder that matches ``self.dataset_group``."""
        group = self.dataset_group
        if group == 'mol':
            return AtomEncoder(emb_dim)
        if group == 'ppa':
            # uniform input node embedding
            return torch.nn.Embedding(1, emb_dim)
        if group == 'RotatedMNIST':
            return torch.nn.Linear(1, emb_dim)
        if group == 'ColoredMNIST':
            return torch.nn.Linear(2, emb_dim)
        if group in ('SBM', 'UPFD'):
            return torch.nn.Embedding(8, emb_dim)
        raise NotImplementedError

    def _make_conv(self, emb_dim):
        """Return one message-passing layer for ``self.gnn_type``."""
        use_stock = self.dataset_group in self._STOCK_CONV_GROUPS
        kind = self.gnn_type

        if kind == 'gin':
            if not use_stock:
                return GINConvNew(emb_dim, self.dataset_group)
            mlp = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2 * emb_dim),
                torch.nn.BatchNorm1d(2 * emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2 * emb_dim, emb_dim))
            return GINConv(mlp, train_eps=True)
        if kind == 'gcn':
            if use_stock:
                return GCNConv(emb_dim, emb_dim)
            return GCNConvNew(emb_dim, self.dataset_group)
        if kind == 'cheb':
            if use_stock:
                return ChebConv(emb_dim, emb_dim, Cheb_K)
            return ChebConvNew(emb_dim, Cheb_K, self.dataset_group)
        raise ValueError('Undefined GNN type called {}'.format(kind))

    def destroy_node_encoder(self):
        """Drop the input encoder; forward() then consumes ``x`` as-is."""
        self.node_encoder = None

    def forward(self, batched_data, perturb=None):
        """Compute node embeddings for a batched graph.

        Reads ``x``, ``edge_index`` and ``edge_attr`` from
        ``batched_data``. ``perturb`` (FLAG perturbation), when given, is
        added to the encoded input; it is ignored when the node encoder
        has been destroyed.

        Raises ValueError if ``self.JK`` is neither "last" nor "sum".
        """
        x = batched_data.x
        edge_index = batched_data.edge_index
        edge_attr = batched_data.edge_attr

        # FLAG injects perturbation
        if self.node_encoder is None:
            h_list = [x]
        else:
            h_in = self.node_encoder(x)
            if perturb is not None:
                h_in = h_in + perturb
            h_list = [h_in]

        last_layer = self.num_layer - 1
        for layer in range(self.num_layer):
            # NOTE(review): every layer type receives edge_attr as its
            # third positional argument - confirm that this matches the
            # forward() signature of the stock torch_geometric layers.
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            if self.batchnorm:
                h = self.batch_norms[layer](h)

            # remove relu for the last layer
            if layer != last_layer:
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)

            if self.residual:
                # Out-of-place on purpose: when dropout is a pass-through
                # ``h`` is the activation tensor itself, and modifying it
                # in place can invalidate what autograd saved for backward.
                h = h + h_list[layer]

            h_list.append(h)

        # Different implementations of Jk-concat
        if self.JK == "last":
            return h_list[-1]
        if self.JK == "sum":
            return sum(h_list)
        raise ValueError('Undefined JK mode called {}'.format(self.JK))
201  
202  
  • E266 Too many leading '#' for block comment
203 ### Virtual GNN to generate node embedding
class GNN_node_Virtualnode(torch.nn.Module):
    """Message-passing stack with one extra "virtual node" per graph.

    Before every layer the graph's virtual-node embedding is added to all
    of its nodes; between layers the virtual node is refreshed from the
    sum of the graph's node embeddings through an MLP.

    Output:
        node representations
    """

    # Dataset groups that are built with the stock torch_geometric layers;
    # every other group uses the project-local *New layer variants.
    _STOCK_CONV_GROUPS = ('RotatedMNIST', 'ColoredMNIST', 'SBM', 'UPFD')

    def __init__(self, num_layer, emb_dim, dataset_group='mol',
                 gnn_type='gin', dropout=0.5, JK="last",
                 residual=False, batchnorm=True):
        '''
        num_layer (int): number of GNN message passing layers
        emb_dim (int): node embedding dimensionality
        dataset_group (str): selects the input node encoder and which
            convolution implementation is used
        gnn_type (str): 'gin', 'gcn' or 'cheb'
        dropout (float): dropout ratio applied after every layer and
            after every virtual-node MLP
        JK (str): how per-layer outputs are combined in forward():
            "last" returns the final layer, "sum" adds all of them
        residual (bool): add each layer's input to its output, and add
            the virtual-node update to the previous virtual node
        batchnorm (bool): apply BatchNorm1d after every convolution

        Raises ValueError for num_layer < 2 or an unknown gnn_type, and
        NotImplementedError for an unknown dataset_group.
        '''
        super(GNN_node_Virtualnode, self).__init__()

        self.dataset_group = dataset_group
        self.gnn_type = gnn_type

        self.num_layer = num_layer
        self.drop_ratio = dropout
        self.JK = JK
        # add residual connection or not
        self.residual = residual
        self.batchnorm = batchnorm

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.node_encoder = self._make_node_encoder(emb_dim)

        # set the initial virtual node embedding to 0.
        self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)
        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        # List of GNNs, each paired with a batch norm over node embeddings.
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(num_layer):
            self.convs.append(self._make_conv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

        # List of MLPs to transform virtual node at every layer; the last
        # layer needs none because the virtual node is not used after it.
        self.mlp_virtualnode_list = torch.nn.ModuleList()
        for _ in range(num_layer - 1):
            self.mlp_virtualnode_list.append(torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2 * emb_dim),
                torch.nn.BatchNorm1d(2 * emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2 * emb_dim, emb_dim),
                torch.nn.BatchNorm1d(emb_dim),
                torch.nn.ReLU()))

    def _make_node_encoder(self, emb_dim):
        """Return the input encoder that matches ``self.dataset_group``."""
        group = self.dataset_group
        if group == 'mol':
            return AtomEncoder(emb_dim)
        if group == 'ppa':
            # uniform input node embedding
            return torch.nn.Embedding(1, emb_dim)
        if group == 'RotatedMNIST':
            return torch.nn.Linear(1, emb_dim)
        if group == 'ColoredMNIST':
            return torch.nn.Linear(2, emb_dim)
        if group in ('SBM', 'UPFD'):
            return torch.nn.Embedding(8, emb_dim)
        raise NotImplementedError

    def _make_conv(self, emb_dim):
        """Return one message-passing layer for ``self.gnn_type``."""
        use_stock = self.dataset_group in self._STOCK_CONV_GROUPS
        kind = self.gnn_type

        if kind == 'gin':
            if not use_stock:
                return GINConvNew(emb_dim, self.dataset_group)
            mlp = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2 * emb_dim),
                torch.nn.BatchNorm1d(2 * emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2 * emb_dim, emb_dim))
            return GINConv(mlp, train_eps=True)
        if kind == 'gcn':
            if use_stock:
                return GCNConv(emb_dim, emb_dim)
            return GCNConvNew(emb_dim, self.dataset_group)
        if kind == 'cheb':
            if use_stock:
                return ChebConv(emb_dim, emb_dim, Cheb_K)
            return ChebConvNew(emb_dim, Cheb_K, self.dataset_group)
        raise ValueError('Undefined GNN type called {}'.format(kind))

    def destroy_node_encoder(self):
        """Drop the input encoder; forward() then consumes ``x`` as-is."""
        self.node_encoder = None

    def forward(self, batched_data, perturb=None):
        """Compute node embeddings for a batched graph.

        Reads ``x``, ``edge_index``, ``edge_attr`` and ``batch`` from
        ``batched_data``. ``perturb`` (FLAG perturbation), when given, is
        added to the encoded input; it is ignored when the node encoder
        has been destroyed.

        Raises ValueError if ``self.JK`` is neither "last" nor "sum".
        """
        x = batched_data.x
        edge_index = batched_data.edge_index
        edge_attr = batched_data.edge_attr
        batch = batched_data.batch

        # virtual node embeddings for graphs: one row per graph, all
        # looked up at index 0 of the single-entry embedding table.
        # The graph count is taken from the last entry of ``batch``
        # (assumes ``batch`` is sorted by graph index - TODO confirm).
        num_graphs = batch[-1].item() + 1
        vn_index = torch.zeros(
            num_graphs, dtype=edge_index.dtype, device=edge_index.device)
        virtualnode_embedding = self.virtualnode_embedding(vn_index)

        # FLAG injects perturbation
        if self.node_encoder is None:
            h_list = [x]
        else:
            h_in = self.node_encoder(x)
            if perturb is not None:
                h_in = h_in + perturb
            h_list = [h_in]

        last_layer = self.num_layer - 1
        for layer in range(self.num_layer):
            # add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            # Message passing among graph nodes
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            if self.batchnorm:
                h = self.batch_norms[layer](h)

            # remove relu for the last layer
            if layer != last_layer:
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            # update the virtual nodes
            if layer < last_layer:
                # add message from graph nodes to virtual nodes
                vn_input = (global_add_pool(h_list[layer], batch)
                            + virtualnode_embedding)

                # transform virtual nodes using MLP
                vn_update = F.dropout(
                    self.mlp_virtualnode_list[layer](vn_input),
                    self.drop_ratio, training=self.training)
                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + vn_update
                else:
                    virtualnode_embedding = vn_update

        # Different implementations of Jk-concat
        if self.JK == "last":
            return h_list[-1]
        if self.JK == "sum":
            return sum(h_list)
        raise ValueError('Undefined JK mode called {}'.format(self.JK))