⬅ models/mlp.py source

1 import torch
2 import torch.nn.functional as F
3 from ogb.graphproppred.mol_encoder import AtomEncoder
4 from torch_geometric.nn import global_mean_pool
5  
  • E302 Expected 2 blank lines, found 1
class MLP(torch.nn.Module):
    """Node-wise MLP followed by mean pooling, for graph-level prediction.

    Node features are embedded by a dataset-specific encoder, then passed
    independently through ``num_layers`` blocks of
    Linear -> BatchNorm1d -> ReLU -> dropout (the last block skips the ReLU).
    Node embeddings are mean-pooled per graph and optionally projected to
    ``num_tasks`` outputs. Edges are never read, so no message passing
    happens.
    """

    def __init__(self, gnn_type, dataset_group, num_tasks=128, num_layers=5,
                 emb_dim=300, dropout=0.5, is_pooled=True, **model_kwargs):
        """
        Args:
            - gnn_type: not used by this model; presumably accepted so the
              constructor matches the project's GNN models (verify).
            - dataset_group (str): selects the node encoder; one of 'mol',
              'ppa', 'RotatedMNIST', 'ColoredMNIST', 'SBM', 'UPFD'.
            - num_tasks (int or None): number of binary label tasks.
              default to 128 (number of tasks of ogbg-molpcba). If None,
              no prediction head is built and ``forward`` returns the
              pooled ``emb_dim``-dimensional graph embeddings.
            - num_layers (int): number of Linear/BatchNorm1d layers applied
              to every node; must be at least 2.
            - emb_dim (int): dimensionality of hidden channels
            - dropout (float): dropout ratio applied to hidden channels
            - is_pooled (bool): stored as ``self.is_pooled``.
              NOTE(review): ``forward`` always pools regardless of this
              flag -- confirm whether callers rely on it.
            - **model_kwargs: accepted and ignored.

        Raises:
            ValueError: if ``num_layers < 2``.
            NotImplementedError: if ``dataset_group`` is not supported.
        """
        # nn.Module.__init__ must run before any attribute is assigned on a
        # Module subclass.
        super().__init__()

        self.dataset_group = dataset_group
        self.num_layers = num_layers
        self.dropout = dropout
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.is_pooled = is_pooled

        # Width of this model's output: graph embeddings when there is no
        # prediction head, otherwise one value per task.
        self.d_out = self.emb_dim if num_tasks is None else self.num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.node_encoder = self._build_node_encoder(dataset_group, emb_dim)

        # One Linear + BatchNorm1d pair per layer, applied node-wise. They
        # are created interleaved (as before) so seeded initialisation is
        # unchanged.
        self.fcs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.fcs.append(torch.nn.Linear(emb_dim, emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

        # Pooling function to generate whole-graph embeddings.
        self.pool = global_mean_pool
        if num_tasks is None:
            self.graph_pred_linear = None
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim,
                                                     self.num_tasks)

    @staticmethod
    def _build_node_encoder(dataset_group, emb_dim):
        """Return the node-feature encoder module for ``dataset_group``.

        Raises:
            NotImplementedError: if ``dataset_group`` is not supported.
        """
        if dataset_group == 'mol':
            return AtomEncoder(emb_dim)
        if dataset_group == 'ppa':
            # uniform input node embedding: a single learned vector (index 0)
            return torch.nn.Embedding(1, emb_dim)
        if dataset_group == 'RotatedMNIST':
            # one float feature per node
            return torch.nn.Linear(1, emb_dim)
        if dataset_group == 'ColoredMNIST':
            # two float features per node
            return torch.nn.Linear(2, emb_dim)
        if dataset_group in ('SBM', 'UPFD'):
            # integer node ids in [0, 8)
            return torch.nn.Embedding(8, emb_dim)
        raise NotImplementedError(
            f"Unsupported dataset_group: {dataset_group!r}")

    def forward(self, batched_data):
        """Encode nodes, apply the MLP layers, pool per graph, then predict.

        Args:
            batched_data: batch object exposing ``x`` (node features fed to
                ``self.node_encoder``) and ``batch`` (graph index of every
                node, passed to ``self.pool``).

        Returns:
            Pooled graph embeddings (width ``emb_dim``) when ``num_tasks``
            was None, otherwise the output of ``graph_pred_linear``
            (width ``num_tasks``).
        """
        x = self.node_encoder(batched_data.x)
        last = self.num_layers - 1
        for i, (fc, norm) in enumerate(zip(self.fcs, self.batch_norms)):
            x = norm(fc(x))
            # Every layer except the last is followed by a ReLU.
            if i < last:
                x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
        x = self.pool(x, batched_data.batch)

        if self.graph_pred_linear is None:
            return x
        return self.graph_pred_linear(x)