1 import torch
            
            
               
               
                  - 
                     
                        F401
                     
                     'torch.nn.functional as F' imported but unused
 
               
               
2 import torch.nn.functional as F
             
            3 import torch.nn as nn
            
            4 from torch_geometric.utils import degree
            
            5  
            
            6 from .models_misc import mlp
            
            7 from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
            
            
               8 from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims 
             
            9  
            
            10  
            
            11 def multi_class_accuracy(y_hat, y, reduction='sum'):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
12     
             
            13     pred = y_hat.max(1)[1]
            
            14     if reduction == 'sum':
            
            15         acc = pred.eq(y).sum().float()
            
            16     elif reduction == 'mean':
            
            17         acc = pred.eq(y).mean().float()
            
            18     else:
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (94 > 79 characters)
 
               
               
19         raise NotImplementedError('Reduction {} not currently implemented.'.format(reduction))
             
            20     return acc
            
            21  
            
            22  
            
            23 def global_add_pool_sparse(x, batch):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
24     
             
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
25     #-------------- global sum pooling
             
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (95 > 79 characters)
 
               
                  - 
                     
                        W291
                     
                     Trailing whitespace
 
               
               
26     index = torch.stack([batch, torch.tensor(list(range(batch.shape[0])), device=x.device)], 0)  
             
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (107 > 79 characters)
 
               
               
27     x_sparse = torch.sparse.FloatTensor(index, x, torch.Size([torch.max(batch)+1, x.shape[0], x.shape[1]]))
             
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
28         
             
            29     return torch.sparse.sum(x_sparse, 1).to_dense()
            
            30  
            
            31  
            
            32 def global_mean_pool_sparse(x, batch):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
33     
             
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
34     #-------------- global average pooling
             
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (95 > 79 characters)
 
               
                  - 
                     
                        W291
                     
                     Trailing whitespace
 
               
               
35     index = torch.stack([batch, torch.tensor(list(range(batch.shape[0])), device=x.device)], 0)  
             
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (107 > 79 characters)
 
               
               
36     x_sparse = torch.sparse.FloatTensor(index, x, torch.Size([torch.max(batch)+1, x.shape[0], x.shape[1]]))
             
            37  
            
            38     graph_sizes = degree(batch).float()
            
            
               
               
                  - 
                     
                        E225
                     
                     Missing whitespace around operator
 
               
               
39     graph_sizes[graph_sizes==0.0] = 1.0
             
            40  
            
            41     return torch.sparse.sum(x_sparse, 1).to_dense() / graph_sizes.unsqueeze(1)
            
            42  
            
            43  
            
            44 class DiscreteEmbedding(torch.nn.Module):
            
            45  
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (91 > 79 characters)
 
               
               
46     def __init__(self, encoder_name, d_in_features, d_in_encoder, d_out_encoder, **kwargs):
             
            47  
            
            48         super(DiscreteEmbedding, self).__init__()
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
49         
             
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
50         #-------------- various different embedding layers
             
            51         kwargs['init'] = None if 'init' not in kwargs else kwargs['init']
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
52     
             
            53         self.encoder_name = encoder_name
            
            
               54         # d_in_features: input feature size (e.g. if already one hot encoded), 
             
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (99 > 79 characters)
 
               
               
55         # d_in_encoder: number of unique values that will be encoded (size of embedding vocabulary)
             
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
56         
             
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
57         #-------------- fill embedding with zeros
             
            58         if encoder_name == 'zero_encoder':
            
            59             self.encoder = zero_encoder(d_out_encoder)
            
            60             d_out = d_out_encoder
            
            61  
            
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
62         #-------------- linear projection
             
            63         elif encoder_name == 'linear':
            
            64             self.encoder = nn.Linear(d_in_features,  d_out_encoder, bias=True)
            
            65             d_out = d_out_encoder
            
            66  
            
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
67         #-------------- mlp
             
            68         elif encoder_name == 'mlp':
            
            69             self.encoder = mlp(d_in_features,
            
            
            71                                d_out_encoder,
            
            72                                kwargs['seed'],
            
            73                                kwargs['activation_mlp'],
            
            74                                kwargs['bn_mlp'])
            
            75             d_out = d_out_encoder
            
            76  
            
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
77         #-------------- multi hot encoding of categorical data
             
            78         elif encoder_name == 'one_hot_encoder':
            
            79             self.encoder = one_hot_encoder(d_in_encoder)
            
            80             d_out = sum(d_in_encoder)
            
            81  
            
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
                  - 
                     
                        E501
                     
                     Line too long (107 > 79 characters)
 
               
               
82         #-------------- embedding of categorical data (linear projection without bias of one hot encodings)
             
            83         elif encoder_name == 'embedding':
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (103 > 79 characters)
 
               
               
84             self.encoder = multi_embedding(d_in_encoder, d_out_encoder, kwargs['aggr'], kwargs['init'])
             
            85             if kwargs['aggr'] == 'concat':
            
            86                 d_out = len(d_in_encoder) * d_out_encoder
            
            87             else:
            
            88                 d_out = d_out_encoder
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
89                 
             
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
90         #-------------- for ogb: multi hot encoding of node features
             
            91         elif encoder_name == 'atom_one_hot_encoder':
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (131 > 79 characters)
 
               
               
92             full_atom_feature_dims = get_atom_feature_dims() if kwargs['features_scope'] == 'full' else get_atom_feature_dims()[:2]
             
            93             self.encoder = one_hot_encoder(full_atom_feature_dims)
            
            94             d_out = sum(full_atom_feature_dims)
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
95         
             
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
96         #-------------- for ogb: multi hot encoding of edge features
             
            
               
               
                  - 
                     
                        E221
                     
                     Multiple spaces before operator
 
               
               
97         elif encoder_name  == 'bond_one_hot_encoder':
             
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (132 > 79 characters)
 
               
                  - 
                     
                        E271
                     
                     Multiple spaces after keyword
 
               
               
98             full_bond_feature_dims = get_bond_feature_dims() if kwargs['features_scope'] == 'full' else  get_bond_feature_dims()[:2]
             
            
               
               
                  - 
                     
                        E221
                     
                     Multiple spaces before operator
 
               
               
99             self.encoder  = one_hot_encoder(full_bond_feature_dims)
             
            100             d_out = sum(full_bond_feature_dims)
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
101                 
             
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
102         #-------------- for ogb: embedding of node features
             
            103         elif encoder_name == 'atom_encoder':
            
            
               
               
                  - 
                     
                        E221
                     
                     Multiple spaces before operator
 
               
               
104             self.encoder  = AtomEncoder(d_out_encoder)
             
            105             d_out = d_out_encoder
            
            106  
            
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
107         #-------------- for ogb: embedding of edge features
             
            
               
               
                  - 
                     
                        E221
                     
                     Multiple spaces before operator
 
               
               
108         elif encoder_name  == 'bond_encoder':
             
            
               
               
                  - 
                     
                        E221
                     
                     Multiple spaces before operator
 
               
                  - 
                     
                        E251
                     
                     Unexpected spaces around keyword / parameter equals (in 2 places)
 
               
               
109             self.encoder  = BondEncoder(emb_dim = d_out_encoder)
             
            110             d_out = d_out_encoder
            
            111  
            
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
112         #-------------- no embedding, use as is
             
            113         elif encoder_name == 'None':
            
            
               
               
                  - 
                     
                        E221
                     
                     Multiple spaces before operator
 
               
               
114             self.encoder  = None
             
            115             d_out = d_in_features
            
            116  
            
            117         else:
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (100 > 79 characters)
 
               
               
118             raise NotImplementedError('Encoder {} is not currently supported.'.format(encoder_name))
             
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
119             
             
            120         self.d_out = d_out
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
121         
             
            122         return
            
            123  
            
            124     def forward(self, x):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
125         
             
            126         x = x.unsqueeze(-1) if x.dim() == 1 else x
            
            127         if self.encoder is not None:
            
            
               
               
                  - 
                     
                        E222
                     
                     Multiple spaces after operator
 
               
                  - 
                     
                        E501
                     
                     Line too long (103 > 79 characters)
 
               
               
128             x = x.float() if self.encoder_name ==  'linear' or self.encoder_name == 'mlp' else x.long()
             
            129             return self.encoder(x)
            
            130         else:
            
            
            132  
            
            133  
            
            134 class multi_embedding(torch.nn.Module):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
135     
             
            
               
               
                  - 
                     
                        E251
                     
                     Unexpected spaces around keyword / parameter equals (in 2 places)
 
               
               
136     def __init__(self, d_in, d_out, aggr = 'concat', init=None):
             
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
137         
             
            138         super(multi_embedding, self).__init__()
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
139         
             
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
                  - 
                     
                        E501
                     
                     Line too long (123 > 79 characters)
 
               
               
140         #-------------- embedding of multiple categorical features. Summation or concatenation of the embeddings is allowed
             
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
141         
             
            142         self.d_in = d_in
            
            143         self.aggr = aggr
            
            144         self.encoder = []
            
            145         for i in range(len(d_in)):
            
            146             self.encoder.append(nn.Embedding(d_in[i], d_out))
            
            147             if init == 'zeros':
            
            148                 print('### INITIALIZING EMBEDDING TO ZERO ###')
            
            149                 torch.nn.init.constant_(self.encoder[i].weight.data, 0)
            
            150             else:
            
            151                 torch.nn.init.xavier_uniform_(self.encoder[-1].weight.data)
            
            
               152         self.encoder = nn.ModuleList(self.encoder)   
             
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
153         
             
            
            155  
            
            156     def forward(self, tensor):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
157         
             
            158         for i in range(tensor.shape[1]):
            
            
               
               
                  - 
                     
                        E231
                     
                     Missing whitespace after ','
 
               
               
159             embedding_i = self.encoder[i](tensor[:,i])
             
            160             if self.aggr == 'concat':
            
            
               
               
                  - 
                     
                        F821
                     
                     Undefined name 'embedding'
 
               
                  - 
                     
                        E231
                     
                     Missing whitespace after ','
 
               
                  - 
                     
                        E225
                     
                     Missing whitespace around operator
 
               
                  - 
                     
                        E501
                     
                     Line too long (89 > 79 characters)
 
               
               
161                 embedding = torch.cat((embedding, embedding_i),1) if i>0 else embedding_i
             
            162             elif self.aggr == 'sum':
            
            
               
               
                  - 
                     
                        E225
                     
                     Missing whitespace around operator
 
               
               
163                 embedding = embedding + embedding_i if i>0 else embedding_i
             
            164             else:
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (121 > 79 characters)
 
               
               
165                 raise NotImplementedError('multi embedding aggregation {} is not currently supported.'.format(self.aggr))
             
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
166         
             
            167         return embedding
            
            168  
            
            169  
            
            170 class one_hot_encoder(torch.nn.Module):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
171     
             
            172     def __init__(self, d_in):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
173         
             
            174         super(one_hot_encoder, self).__init__()
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
175         
             
            176         self.d_in = d_in
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
177         
             
            
            179  
            
            180     def forward(self, tensor):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
181         
             
            182         for i in range(tensor.shape[1]):
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (89 > 79 characters)
 
               
               
183             onehot_i = torch.zeros((tensor.shape[0], self.d_in[i]), device=tensor.device)
             
            
               
               
                  - 
                     
                        E231
                     
                     Missing whitespace after ','
 
               
               
184             onehot_i.scatter_(1, tensor[:,i:i+1], 1)
             
            
               
               
                  - 
                     
                        F821
                     
                     Undefined name 'onehot'
 
               
                  - 
                     
                        E225
                     
                     Missing whitespace around operator
 
               
               
185             onehot = torch.cat((onehot, onehot_i), 1) if i>0 else onehot_i
             
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
186         
             
            187         return onehot
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
188     
             
            189     def __repr__(self):
            
            190         return '{}({})'.format(self.__class__.__name__, self.d_in)
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
191     
             
            192  
            
            193 class zero_encoder(torch.nn.Module):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
194     
             
            195     def __init__(self, d_out):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
196         
             
            197         super(zero_encoder, self).__init__()
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
198         
             
            199         self.d_out = d_out
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
200         
             
            
            202  
            
            203     def forward(self, tensor):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
204         
             
            205         return torch.zeros((tensor.shape[0], self.d_out), device=tensor.device)
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
206     
             
            207     def __repr__(self):
            
            
               208         return '{}({})'.format(self.__class__.__name__, self.d_out) 
             
            209  
            
            210  
            
            211 class central_encoder(nn.Module):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
212     
             
            213     def __init__(self, nb_encoder, d_ef, extend=True):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
214         
             
            215         super(central_encoder, self).__init__()
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
216         
             
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
217         #-------------- For the neighbor aggregation: central node embedding
             
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
                  - 
                     
                        E501
                     
                     Line too long (92 > 79 characters)
 
               
               
218         #-------------- This is a way to create a dummy variable that represents self loops.
             
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
               
219         #-------------- Useful when working with edge features or GSN-e
             
            
               
               
                  - 
                     
                        E265
                     
                     Block comment should start with '# '
 
               
                  - 
                     
                        E501
                     
                     Line too long (119 > 79 characters)
 
               
               
220         #-------------- Two ways are allowed: extra dummy variable (one hot or embedding) or a vector filled with zeros
             
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
221         
             
            222         self.extend = extend
            
            223         self.nb_encoder = nb_encoder
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
224         
             
            225         if self.extend:
            
            226             print('##### EXTENDING EDGE FEATURE DIMENSIONS #####')
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
227         
             
            228         if 'one_hot_encoder' in nb_encoder:
            
            229             if self.extend:
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (86 > 79 characters)
 
               
               
230                 self.encoder = DiscreteEmbedding('one_hot_encoder', 1, [d_ef+1], None)
             
            231                 self.d_out = d_ef+1
            
            232             else:
            
            233                 self.d_out = d_ef
            
            234         else:
            
            235             self.d_out = d_ef
            
            236             if self.extend:
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (91 > 79 characters)
 
               
               
237                 self.encoder = DiscreteEmbedding('embedding',  None, [1], d_ef, aggr='sum')
             
            238             else:
            
            239                 pass
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
240             
             
            241         return
            
            242  
            
            243     def forward(self, x_nb, num_nodes):
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
244         
             
            245         if 'one_hot_encoder' in self.nb_encoder:
            
            246             if self.extend:
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (84 > 79 characters)
 
               
               
247                 zero_extension = torch.zeros((x_nb.shape[0], 1), device=x_nb.device)
             
            248                 x_nb = torch.cat((zero_extension, x_nb), -1)
            
            
               
               
                  - 
                     
                        E231
                     
                     Missing whitespace after ','
 
               
                  - 
                     
                        E501
                     
                     Line too long (81 > 79 characters)
 
               
               
249                 x_central = torch.zeros((num_nodes,1), device=x_nb.device).long()
             
            250                 x_central = self.encoder(x_central)
            
            251             else:
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (84 > 79 characters)
 
               
               
252                 x_central = torch.zeros((num_nodes, self.d_out), device=x_nb.device)
             
            253         else:
            
            254             if self.extend:
            
            
               
               
                  - 
                     
                        E231
                     
                     Missing whitespace after ','
 
               
                  - 
                     
                        E501
                     
                     Line too long (81 > 79 characters)
 
               
               
255                 x_central = torch.zeros((num_nodes,1), device=x_nb.device).long()
             
            256                 x_central = self.encoder(x_central)
            
            257             else:
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (84 > 79 characters)
 
               
               
258                 x_central = torch.zeros((num_nodes, self.d_out), device=x_nb.device)
             
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
259             
             
            260         return x_central, x_nb
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
261     
             
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
262