⬅ models/gsn/graph_filters/utils_graph_learning.py source

1 import torch
  • F401 'torch.nn.functional as F' imported but unused
2 import torch.nn.functional as F
3 import torch.nn as nn
4 from torch_geometric.utils import degree
5  
6 from .models_misc import mlp
7 from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
  • W291 Trailing whitespace
8 from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
9  
10  
def multi_class_accuracy(y_hat, y, reduction='sum'):
    """Score arg-max predictions against targets.

    Args:
        y_hat: score tensor; the prediction is the arg-max over dim 1.
        y: target class indices, compared element-wise with the predictions.
        reduction: 'sum' returns the number of correct predictions,
            'mean' returns the fraction of correct predictions.

    Returns:
        A 0-dim float tensor.

    Raises:
        NotImplementedError: for any other ``reduction``.
    """
    pred = y_hat.max(1)[1]
    correct = pred.eq(y)
    if reduction == 'sum':
        return correct.sum().float()
    if reduction == 'mean':
        # Bool tensors have no mean(); cast to float before averaging
        # (calling .mean() on the bool result raises a RuntimeError).
        return correct.float().mean()
    raise NotImplementedError(
        'Reduction {} not currently implemented.'.format(reduction))
21  
22  
def global_add_pool_sparse(x, batch):
    """Sum node features per graph (global sum pooling).

    Args:
        x: node feature matrix of shape (num_nodes, num_features).
        batch: integer tensor with the graph id of every row of ``x``.

    Returns:
        Dense tensor of shape (max(batch) + 1, num_features). Row g is the
        sum of the rows of ``x`` assigned to graph g; graph ids with no
        nodes give a zero row. An empty ``batch`` gives (0, num_features).
    """
    num_nodes, num_features = x.shape[0], x.shape[1]
    if batch.numel() == 0:
        # torch.max on an empty tensor raises; there is nothing to pool.
        return x.new_zeros((0, num_features))
    num_graphs = int(batch.max()) + 1
    # Sparse (graph, node) -> feature-row tensor; summing out the node
    # dimension pools every graph at once.
    index = torch.stack(
        [batch, torch.arange(num_nodes, device=x.device)], 0)
    # torch.sparse_coo_tensor replaces the deprecated legacy
    # torch.sparse.FloatTensor constructor, which only accepts CPU float32
    # tensors; this version keeps x's device and dtype.
    x_sparse = torch.sparse_coo_tensor(
        index, x, (num_graphs, num_nodes, num_features))
    return torch.sparse.sum(x_sparse, 1).to_dense()
30  
31  
def global_mean_pool_sparse(x, batch):
    """Average node features per graph (global mean pooling).

    Args:
        x: node feature matrix of shape (num_nodes, num_features).
        batch: integer tensor with the graph id of every row of ``x``.

    Returns:
        Dense tensor of shape (max(batch) + 1, num_features). Row g is the
        mean of the rows of ``x`` assigned to graph g; graph ids with no
        nodes give a zero row (their size is clamped to 1 to avoid NaN).
        An empty ``batch`` gives (0, num_features).
    """
    num_nodes, num_features = x.shape[0], x.shape[1]
    if batch.numel() == 0:
        # torch.max on an empty tensor raises; there is nothing to pool.
        return x.new_zeros((0, num_features))
    num_graphs = int(batch.max()) + 1
    index = torch.stack(
        [batch, torch.arange(num_nodes, device=x.device)], 0)
    # torch.sparse_coo_tensor replaces the deprecated legacy
    # torch.sparse.FloatTensor constructor (CPU float32 only).
    x_sparse = torch.sparse_coo_tensor(
        index, x, (num_graphs, num_nodes, num_features))
    sums = torch.sparse.sum(x_sparse, 1).to_dense()

    # Nodes per graph id 0..num_graphs-1 (same counts as degree(batch)).
    graph_sizes = torch.bincount(batch, minlength=num_graphs).to(sums.dtype)
    graph_sizes[graph_sizes == 0] = 1.0
    return sums / graph_sizes.unsqueeze(1)
42  
43  
class DiscreteEmbedding(torch.nn.Module):
    """Select one of several feature encoders by name behind one interface.

    Args:
        encoder_name: one of 'zero_encoder', 'linear', 'mlp',
            'one_hot_encoder', 'embedding', 'atom_one_hot_encoder',
            'bond_one_hot_encoder', 'atom_encoder', 'bond_encoder' or the
            string 'None' (no encoder, input used as is).
        d_in_features: input feature size (e.g. if already one hot encoded);
            used by 'linear', 'mlp' and 'None'.
        d_in_encoder: number of unique values that will be encoded per
            categorical column (size of each embedding vocabulary); used by
            'one_hot_encoder' and 'embedding'.
        d_out_encoder: output size of the learned/zero encoders.
        **kwargs: encoder-specific options:
            'seed', 'activation_mlp', 'bn_mlp' (required by 'mlp');
            'aggr' (for 'embedding', defaults to 'concat' like
            multi_embedding), 'init' (for 'embedding', defaults to None);
            'features_scope' (required by the atom/bond one-hot encoders;
            'full' uses every OGB feature, anything else the first two).

    Attributes:
        encoder: the underlying module, or None for 'None'.
        d_out: width of the tensor returned by ``forward``.

    Raises:
        NotImplementedError: for an unknown ``encoder_name``.
    """

    def __init__(self, encoder_name, d_in_features, d_in_encoder,
                 d_out_encoder, **kwargs):
        super(DiscreteEmbedding, self).__init__()

        self.encoder_name = encoder_name
        init = kwargs.get('init')
        # Same default as multi_embedding's own `aggr`, so omitting it no
        # longer raises KeyError.
        aggr = kwargs.get('aggr', 'concat')

        # -------------- fill embedding with zeros
        if encoder_name == 'zero_encoder':
            self.encoder = zero_encoder(d_out_encoder)
            d_out = d_out_encoder

        # -------------- linear projection
        elif encoder_name == 'linear':
            self.encoder = nn.Linear(d_in_features, d_out_encoder, bias=True)
            d_out = d_out_encoder

        # -------------- mlp
        elif encoder_name == 'mlp':
            self.encoder = mlp(d_in_features,
                               d_out_encoder,
                               d_out_encoder,
                               kwargs['seed'],
                               kwargs['activation_mlp'],
                               kwargs['bn_mlp'])
            d_out = d_out_encoder

        # -------------- multi hot encoding of categorical data
        elif encoder_name == 'one_hot_encoder':
            self.encoder = one_hot_encoder(d_in_encoder)
            d_out = sum(d_in_encoder)

        # -------------- embedding of categorical data
        # (linear projection without bias of one hot encodings)
        elif encoder_name == 'embedding':
            self.encoder = multi_embedding(d_in_encoder, d_out_encoder,
                                           aggr, init)
            if aggr == 'concat':
                d_out = len(d_in_encoder) * d_out_encoder
            else:
                d_out = d_out_encoder

        # -------------- for ogb: multi hot encoding of node features
        elif encoder_name == 'atom_one_hot_encoder':
            if kwargs['features_scope'] == 'full':
                full_atom_feature_dims = get_atom_feature_dims()
            else:
                full_atom_feature_dims = get_atom_feature_dims()[:2]
            self.encoder = one_hot_encoder(full_atom_feature_dims)
            d_out = sum(full_atom_feature_dims)

        # -------------- for ogb: multi hot encoding of edge features
        elif encoder_name == 'bond_one_hot_encoder':
            if kwargs['features_scope'] == 'full':
                full_bond_feature_dims = get_bond_feature_dims()
            else:
                full_bond_feature_dims = get_bond_feature_dims()[:2]
            self.encoder = one_hot_encoder(full_bond_feature_dims)
            d_out = sum(full_bond_feature_dims)

        # -------------- for ogb: embedding of node features
        elif encoder_name == 'atom_encoder':
            self.encoder = AtomEncoder(d_out_encoder)
            d_out = d_out_encoder

        # -------------- for ogb: embedding of edge features
        elif encoder_name == 'bond_encoder':
            self.encoder = BondEncoder(emb_dim=d_out_encoder)
            d_out = d_out_encoder

        # -------------- no embedding, use as is
        elif encoder_name == 'None':
            self.encoder = None
            d_out = d_in_features

        else:
            raise NotImplementedError(
                'Encoder {} is not currently supported.'.format(encoder_name))

        self.d_out = d_out

    def forward(self, x):
        """Encode ``x``; a 1-D input is treated as a single feature column."""
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        if self.encoder is None:
            return x.float()
        # 'linear'/'mlp' consume real-valued features; every other encoder
        # here indexes with integer category ids.
        if self.encoder_name in ('linear', 'mlp'):
            x = x.float()
        else:
            x = x.long()
        return self.encoder(x)
132  
133  
class multi_embedding(torch.nn.Module):
    """Embed several categorical feature columns and combine the results.

    Args:
        d_in: vocabulary size of each column; one nn.Embedding per entry.
        d_out: embedding size shared by every column.
        aggr: 'concat' (output width len(d_in) * d_out) or 'sum'
            (output width d_out).
        init: 'zeros' to zero-initialise every table; any other value uses
            Xavier-uniform initialisation.
    """

    def __init__(self, d_in, d_out, aggr='concat', init=None):
        super(multi_embedding, self).__init__()

        self.d_in = d_in
        self.aggr = aggr
        tables = []
        # Build and initialise each table in order (keeps the same RNG
        # consumption order as creating them one by one).
        for vocab_size in d_in:
            table = nn.Embedding(vocab_size, d_out)
            if init == 'zeros':
                print('### INITIALIZING EMBEDDING TO ZERO ###')
                torch.nn.init.constant_(table.weight.data, 0)
            else:
                torch.nn.init.xavier_uniform_(table.weight.data)
            tables.append(table)
        self.encoder = nn.ModuleList(tables)

    def forward(self, tensor):
        """Embed column i of ``tensor`` with table i and aggregate.

        Args:
            tensor: integer tensor whose dim 1 indexes the feature columns.

        Raises:
            NotImplementedError: if ``self.aggr`` is not 'concat' or 'sum'.
            ValueError: if ``tensor`` has no feature columns.
        """
        if self.aggr not in ('concat', 'sum'):
            raise NotImplementedError(
                'multi embedding aggregation {} is not currently supported.'
                .format(self.aggr))
        num_cols = tensor.shape[1]
        if num_cols == 0:
            raise ValueError(
                'multi_embedding expects at least one feature column.')

        embeddings = [self.encoder[i](tensor[:, i]) for i in range(num_cols)]
        if self.aggr == 'concat':
            # Single concatenation instead of re-concatenating per column.
            return torch.cat(embeddings, 1)
        total = embeddings[0]
        for embedding_i in embeddings[1:]:
            total = total + embedding_i
        return total
168  
169  
class one_hot_encoder(torch.nn.Module):
    """Multi-hot encode categorical columns as concatenated one-hot blocks.

    Args:
        d_in: number of categories of each column; the output width is
            ``sum(d_in)``.
    """

    def __init__(self, d_in):
        super(one_hot_encoder, self).__init__()
        self.d_in = d_in

    def forward(self, tensor):
        """One-hot encode column i of ``tensor`` into ``d_in[i]`` slots.

        Args:
            tensor: integer (int64) tensor whose dim 1 indexes the feature
                columns; ``scatter_`` requires int64 indices.

        Raises:
            ValueError: if ``tensor`` has no feature columns.
        """
        num_rows, num_cols = tensor.shape[0], tensor.shape[1]
        if num_cols == 0:
            raise ValueError(
                'one_hot_encoder expects at least one feature column.')
        blocks = []
        for i in range(num_cols):
            onehot_i = torch.zeros((num_rows, self.d_in[i]),
                                   device=tensor.device)
            onehot_i.scatter_(1, tensor[:, i:i + 1], 1)
            blocks.append(onehot_i)
        # Single concatenation instead of re-concatenating per column.
        return torch.cat(blocks, 1)

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.d_in)
  • W293 Blank line contains whitespace
191
192  
class zero_encoder(torch.nn.Module):
    """Encoder that ignores the input values and returns all zeros.

    Args:
        d_out: width of the produced tensor.
    """

    def __init__(self, d_out):
        super(zero_encoder, self).__init__()
        self.d_out = d_out

    def forward(self, tensor):
        """Return zeros of shape (tensor.shape[0], d_out) on tensor's device."""
        num_rows = tensor.shape[0]
        target_device = tensor.device
        return torch.zeros((num_rows, self.d_out), device=target_device)

    def __repr__(self):
        class_name = type(self).__name__
        return '{}({})'.format(class_name, self.d_out)
209  
210  
class central_encoder(nn.Module):
    """Provide a self-loop representation for the central node.

    For the neighbour aggregation this creates a dummy variable that
    represents self loops, which is useful when working with edge features
    or GSN-e. Two modes:

    * ``extend=True``: an extra dummy variable. For one-hot edge encoders
      the edge features gain a leading zero slot and the central node is
      one-hot in that slot. Otherwise a learned single-entry embedding of
      width ``d_ef`` is used.
    * ``extend=False``: the central node gets a vector filled with zeros.

    Args:
        nb_encoder: name of the neighbour/edge encoder. Any name containing
            'one_hot_encoder' selects the one-hot behaviour.
        d_ef: edge feature dimension.
        extend: whether to add the dummy self-loop variable.
    """

    def __init__(self, nb_encoder, d_ef, extend=True):
        super(central_encoder, self).__init__()

        self.extend = extend
        self.nb_encoder = nb_encoder
        uses_one_hot = 'one_hot_encoder' in nb_encoder

        if extend:
            print('##### EXTENDING EDGE FEATURE DIMENSIONS #####')

        if not extend:
            # zero vectors of the plain edge-feature width; no encoder needed
            self.d_out = d_ef
        elif uses_one_hot:
            # one extra one-hot slot (index 0) reserved for the self loop
            self.encoder = DiscreteEmbedding('one_hot_encoder', 1,
                                             [d_ef + 1], None)
            self.d_out = d_ef + 1
        else:
            # a single learnable vector of width d_ef for the self loop
            self.d_out = d_ef
            self.encoder = DiscreteEmbedding('embedding', None, [1], d_ef,
                                             aggr='sum')

    def forward(self, x_nb, num_nodes):
        """Return (central-node features, possibly-extended neighbour features).

        Args:
            x_nb: neighbour/edge feature tensor (rows along dim 0).
            num_nodes: number of central nodes to produce features for.
        """
        dev = x_nb.device

        if not self.extend:
            zeros_central = torch.zeros((num_nodes, self.d_out), device=dev)
            return zeros_central, x_nb

        if 'one_hot_encoder' in self.nb_encoder:
            # prepend the (empty) self-loop slot to every edge feature row
            pad = torch.zeros((x_nb.shape[0], 1), device=dev)
            x_nb = torch.cat((pad, x_nb), -1)

        # category 0 of the single self-loop feature for every node
        central_ids = torch.zeros((num_nodes, 1), device=dev).long()
        return self.encoder(central_ids), x_nb
  • W293 Blank line contains whitespace
261
  • W293 Blank line contains whitespace
262