models/gsn/gnn.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from .graph_filters.GSN_edge_sparse_ogb import GSN_edge_sparse_ogb
from .graph_filters.GSN_sparse import GSN_sparse

from .graph_filters.MPNN_edge_sparse_ogb import MPNN_edge_sparse_ogb

from .models_misc import choose_activation
from .utils_graph_learning import (global_add_pool_sparse,
                                   global_mean_pool_sparse,
                                   DiscreteEmbedding)


class GNN_GSN(torch.nn.Module):
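    """Graph Substructure Network (GSN) model for sparse graph batches.

    Stacks `num_layers` message-passing layers (GSN filters that consume
    structural identifiers, or a plain MPNN fallback), interleaved with
    batch norm and dropout. The intermediate representations selected by
    `final_projection` are summed, pooled into a per-graph embedding,
    and optionally mapped to `out_features` by a linear head.
    """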
    def __init__(self,
                 in_features,
                 out_features,
                 encoder_ids,
                 d_in_id,
                 in_edge_features=None,
                 d_in_node_encoder=None,
                 d_in_edge_encoder=None,
                 d_degree=None,
                 emb_dim=300,
                 num_layers=5,
                 dataset_group='mol'):

        super(GNN_GSN, self).__init__()

        seed = 0

        # -------------- Initializations
        self.dataset_group = dataset_group
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        self.model_name = 'GSN_edge_sparse_ogb'
        self.readout = 'mean'
        self.dropout_features = 0.5
        self.bn = True

        # Only the last intermediate representation feeds the prediction,
        # e.g. [False, False, False, False, False, True] for 5 layers.
        self.final_projection = [False for _ in range(self.num_layers)]
        self.final_projection.append(True)

        self.residual = False
        self.inject_ids = False

        id_scope = 'local'
        d_msg = emb_dim
        d_out = emb_dim
        d_h = [2 * emb_dim]
        aggr = 'add'
        flow = 'source_to_target'
        msg_kind = 'ogb'
        train_eps = False
        activation_mlp = 'relu'
        bn_mlp = True

        # All layers but the first retain the raw node features,
        # i.e. retain_features = [False, True, True, True, True].
        retain_features = [True for _ in range(self.num_layers)]
        retain_features[0] = False

        degree_as_tag = False
        encoders_kwargs = {'seed': seed,
                           'activation_mlp': activation_mlp,
                           'bn_mlp': bn_mlp,
                           'aggr': 'sum',
                           'features_scope': 'full'}

        # -------------- Input node embedding
        if self.dataset_group == 'mol':
            self.input_node_encoder = DiscreteEmbedding('atom_encoder',
                                                        in_features,
                                                        d_in_node_encoder,
                                                        emb_dim,
                                                        **encoders_kwargs)
            d_in = self.input_node_encoder.d_out
        elif self.dataset_group == 'ppa':
            self.input_node_encoder = torch.nn.Embedding(1, emb_dim)
            d_in = emb_dim
        elif self.dataset_group == 'RotatedMNIST':
            # Named `input_node_encoder` so that `forward` finds it.
            self.input_node_encoder = torch.nn.Linear(3, emb_dim)
            d_in = emb_dim
        else:
            raise NotImplementedError

        # -------------- Edge embedding (for each GNN layer)
        self.edge_encoder = []
        d_ef = []
        if self.dataset_group == 'mol':
            for i in range(self.num_layers):
                edge_encoder_layer = DiscreteEmbedding('bond_encoder',
                                                       in_edge_features,
                                                       d_in_edge_encoder,
                                                       emb_dim,
                                                       **encoders_kwargs)
                self.edge_encoder.append(edge_encoder_layer)
                d_ef.append(edge_encoder_layer.d_out)
        elif self.dataset_group == 'ppa':
            for i in range(self.num_layers):
                edge_encoder_layer = torch.nn.Linear(7, emb_dim)
                self.edge_encoder.append(edge_encoder_layer)
                d_ef.append(emb_dim)

        self.edge_encoder = nn.ModuleList(self.edge_encoder)

        # -------------- Identifier embedding (for each GNN layer)
        self.id_encoder = []
        d_id = []
        num_id_encoders = self.num_layers if self.inject_ids else 1

        for i in range(num_id_encoders):
            id_encoder_layer = DiscreteEmbedding('embedding',
                                                 len(d_in_id),
                                                 d_in_id,
                                                 emb_dim,
                                                 **encoders_kwargs)
            self.id_encoder.append(id_encoder_layer)
            d_id.append(id_encoder_layer.d_out)

        self.id_encoder = nn.ModuleList(self.id_encoder)

        # -------------- Degree embedding
        self.degree_encoder = DiscreteEmbedding('None',
                                                1,
                                                1,
                                                emb_dim,
                                                **encoders_kwargs)

        # -------------- GNN layers w/ bn
        self.conv = []
        self.batch_norms = []
        self.mlp_vn = []
        for i in range(self.num_layers):
            # if i > 0 and self.vn:
            #     # -------------- vn msg function
            #     mlp_vn_temp = mlp(d_in_vn, kwargs['d_out_vn'][i-1],
            #                       d_h[i], seed, activation_mlp, bn_mlp)
            #     self.mlp_vn.append(mlp_vn_temp)
            #     d_in_vn = kwargs['d_out_vn'][i-1]
            kwargs_filter = {
                'd_in': d_in,
                'd_degree': 1,
                'degree_as_tag': degree_as_tag,
                'retain_features': retain_features[i],
                'd_msg': d_msg,
                'd_up': d_out,
                'd_h': d_h,
                'seed': seed,
                'activation_name': activation_mlp,
                'bn': bn_mlp,
                'aggr': aggr,
                'msg_kind': msg_kind,
                'eps': 0,
                'train_eps': train_eps,
                'flow': flow,
                'd_ef': d_ef[i],
                'edge_embedding': 'bond_encoder',
                'id_embedding': 'embedding',
                'extend_dims': True
            }

            # Structural identifiers enter at the first layer only,
            # or at every layer when `inject_ids` is set.
            use_ids = (((i > 0 and self.inject_ids) or (i == 0))
                       and (self.model_name == 'GSN_edge_sparse_ogb'))

            if use_ids:
                if self.dataset_group in ('mol', 'ppa'):
                    filter_fn = GSN_edge_sparse_ogb
                else:
                    filter_fn = GSN_sparse

                kwargs_filter['d_id'] = d_id[i] if self.inject_ids else d_id[0]
                kwargs_filter['id_scope'] = id_scope
            else:
                filter_fn = MPNN_edge_sparse_ogb
            self.conv.append(filter_fn(**kwargs_filter))

            bn_layer = nn.BatchNorm1d(d_out) if self.bn else None
            self.batch_norms.append(bn_layer)

            d_in = d_out

        self.conv = nn.ModuleList(self.conv)
        self.batch_norms = nn.ModuleList(self.batch_norms)
        # if kwargs['vn']:
        #     self.mlp_vn = nn.ModuleList(self.mlp_vn)

        # -------------- Readout
        if self.readout == 'sum':
            self.global_pool = global_add_pool_sparse
        elif self.readout == 'mean':
            self.global_pool = global_mean_pool_sparse
        else:
            raise ValueError("Invalid graph pooling type.")

        # -------------- Virtual node aggregation operator
        # if self.vn:
        #     if kwargs['vn_pooling'] == 'sum':
        #         self.global_vn_pool = global_add_pool_sparse
        #     elif kwargs['vn_pooling'] == 'mean':
        #         self.global_vn_pool = global_mean_pool_sparse
        #     else:
        #         raise ValueError(
        #             "Invalid graph virtual node pooling type.")

        if out_features is None:
            self.lin_proj = None
        else:
            self.lin_proj = nn.Linear(d_out, out_features)

        # -------------- Activation fn (same across the network)
        self.activation = choose_activation('relu')

    def forward(self, data, return_intermediate=False):
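        """Return graph-level embeddings, or predictions when a linear
        head is configured.

        `return_intermediate` is accepted for interface compatibility
        but is not used in this version of the model.
        """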
        # -------------- Code adapted from
        # https://github.com/snap-stanford/ogb/tree/master/examples/graphproppred/mol,
        # modified to allow for the existence of structural identifiers.

        kwargs = {}
        kwargs['degrees'] = self.degree_encoder(data.degrees)

        # -------------- edge index, initial node features embedding,
        # initial vn embedding
        edge_index = data.edge_index
        # if self.vn:
        #     vn_embedding = self.vn_encoder(
        #         torch.zeros(data.batch[-1].item() + 1)
        #         .to(edge_index.dtype).to(edge_index.device))
        x = self.input_node_encoder(data.x)
        x_interm = [x]

        for i in range(0, len(self.conv)):

            # -------------- encode ids (different for each layer)
            kwargs['identifiers'] = (
                self.id_encoder[i](data.identifiers) if self.inject_ids
                else self.id_encoder[0](data.identifiers))

            # -------------- edge features embedding (different per layer)
            if hasattr(data, 'edge_features'):
                kwargs['edge_features'] = self.edge_encoder[i](
                    data.edge_features)
            else:
                kwargs['edge_features'] = None

            # if self.vn:
            #     x_interm[i] = x_interm[i] + vn_embedding[data.batch]

            x = self.conv[i](x_interm[i], edge_index, **kwargs)

            x = self.batch_norms[i](x) if self.bn else x

            # No nonlinearity after the last layer; dropout everywhere.
            if i == len(self.conv) - 1:
                x = F.dropout(x, self.dropout_features,
                              training=self.training)
            else:
                x = F.dropout(self.activation(x), self.dropout_features,
                              training=self.training)

            if self.residual:
                x += x_interm[-1]

            x_interm.append(x)

            # if i < len(self.conv) - 1 and self.vn:
            #     vn_embedding_temp = (
            #         self.global_vn_pool(x_interm[i], data.batch)
            #         + vn_embedding)
            #     vn_embedding = self.mlp_vn[i](vn_embedding_temp)
            #     if self.residual:
            #         vn_embedding = vn_embedding + F.dropout(
            #             self.activation(vn_embedding),
            #             self.dropout_features[i], training=self.training)
            #     else:
            #         vn_embedding = F.dropout(
            #             self.activation(vn_embedding),
            #             self.dropout_features[i], training=self.training)

        # Sum the intermediate representations flagged in final_projection
        # (here: only the last one), then pool over each graph.
        prediction = 0
        for i in range(0, len(self.conv) + 1):
            if self.final_projection[i]:
                prediction += x_interm[i]

        x_global = self.global_pool(prediction, data.batch)
        if self.lin_proj is None:
            return x_global
        else:
            return self.lin_proj(x_global)
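# ----------------------------------------------------------------------
# Usage sketch (illustrative only): constructing the model for an
# OGB-style molecular dataset. The argument values below are hypothetical
# placeholders; in practice they come from the dataset preprocessing that
# counts substructure identifiers, atom features, and bond features.
#
#   model = GNN_GSN(in_features=9,            # hypothetical
#                   out_features=1,           # scalar graph-level target
#                   encoder_ids=None,         # unused by this class
#                   d_in_id=[6, 6],           # identifier vocab sizes
#                   in_edge_features=3,       # hypothetical
#                   d_in_node_encoder=[119],  # atom vocab sizes
#                   d_in_edge_encoder=[5],    # bond vocab sizes
#                   emb_dim=300,
#                   num_layers=5,
#                   dataset_group='mol')
#   out = model(batch)  # `batch` must provide .x, .edge_index, .degrees,
#                       # .identifiers and .batch; .edge_features is
#                       # optional (used when present).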