import torch
import torch.nn as nn
import torch.nn.functional as F

"""
3WLGNN / ThreeWLGNN
Provably Powerful Graph Networks (Maron et al., 2019)
https://papers.nips.cc/paper/8488-provably-powerful-graph-networks.pdf

CODE adapted from
https://github.com/hadarser/ProvablyPowerfulGraphNetworks_torch/
"""


class ThreeWLGNNNet(nn.Module):
    def __init__(self, gnn_type, num_tasks, three_wl_in_dim, n_layers=3,
                 depth_of_mlp=2, hidden_dim=128, residual=False,
                 **model_kwargs):
        assert gnn_type == '3wlgnn'
        super(ThreeWLGNNNet, self).__init__()
        self.in_dim_node = three_wl_in_dim
        self.residual = residual

        # n_layers blocks, each producing hidden_dim features
        block_features = [hidden_dim] * n_layers
        # number of features of the input
        original_features_num = self.in_dim_node + 1

        # sequential mlp blocks
        last_layer_features = original_features_num
        self.reg_blocks = nn.ModuleList()
        for layer, next_layer_features in enumerate(block_features):
            mlp_block = RegularBlock(depth_of_mlp, last_layer_features,
                                     next_layer_features, self.residual)
            self.reg_blocks.append(mlp_block)
            last_layer_features = next_layer_features

        self.fc_layers = nn.ModuleList()
        for output_features in block_features:
            # each block's output is pooled (hence 2 * output_features) and
            # passed through a fully connected layer
            fc = FullyConnected(2 * output_features, num_tasks,
                                activation_fn=None)
            self.fc_layers.append(fc)

    def forward(self, x):
        scores = torch.tensor(0, device=x.device, dtype=x.dtype)
        for i, block in enumerate(self.reg_blocks):
            x = block(x)
            scores = self.fc_layers[i](diag_offdiag_maxpool(x)) + scores
        return scores
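

# Note on the forward pass above: the input x is a dense graph tensor of
# shape [batch, three_wl_in_dim + 1, n, n] (n = number of nodes); after each
# RegularBlock the tensor is max-pooled over diagonal and off-diagonal
# entries and mapped to task scores, and the per-block scores are summed.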


# class ThreeWLGNNEdgeNet(nn.Module):
#     def __init__(self, gnn_type, num_tasks, feature_dim, n_layers=3,
#                  depth_of_mlp=2, hidden_dim=128, residual=False,
#                  **model_kwargs):
#
#         assert gnn_type == '3wlgnn-edge'
#         super(ThreeWLGNNEdgeNet, self).__init__()
#
#         self.in_dim_node = feature_dim
#         self.residual = residual
#
#         # n_layers blocks, each producing hidden_dim features
#         block_features = [hidden_dim] * n_layers
#         # number of features of the input
#         original_features_num = self.in_dim_node + self.num_bond_type + 1
#
#         # sequential mlp blocks
#         last_layer_features = original_features_num
#         self.reg_blocks = nn.ModuleList()
#         for layer, next_layer_features in enumerate(block_features):
#             mlp_block = RegularBlock(depth_of_mlp, last_layer_features,
#                                      next_layer_features, self.residual)
#             self.reg_blocks.append(mlp_block)
#             last_layer_features = next_layer_features
#
#         self.fc_layers = nn.ModuleList()
#         for output_features in block_features:
#             # each block's output is pooled (hence 2 * output_features) and
#             # passed through a fully connected layer
#             fc = FullyConnected(2 * output_features, num_tasks,
#                                 activation_fn=None)
#             self.fc_layers.append(fc)
#
#     def forward(self, x_no_edge_feat, x_with_edge_feat):
#         x = x_with_edge_feat
#
#         scores = torch.tensor(0, device=x.device, dtype=x.dtype)
#         for i, block in enumerate(self.reg_blocks):
#             x = block(x)
#             scores = self.fc_layers[i](diag_offdiag_maxpool(x)) + scores
#         return scores


class MLPReadout(nn.Module):

    def __init__(self, input_dim, output_dim, L=2):  # L = nb hidden layers
        super().__init__()
        list_FC_layers = [
            nn.Linear(input_dim // 2 ** layer,
                      input_dim // 2 ** (layer + 1), bias=True)
            for layer in range(L)
        ]
        list_FC_layers.append(
            nn.Linear(input_dim // 2 ** L, output_dim, bias=True))
        self.FC_layers = nn.ModuleList(list_FC_layers)
        self.L = L

    def forward(self, x):
        y = x
        for layer in range(self.L):
            y = self.FC_layers[layer](y)
            y = F.relu(y)
        y = self.FC_layers[self.L](y)
        return y
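
# With the defaults (L=2), MLPReadout halves the feature dimension at each
# hidden layer; a small illustrative sketch (dims assumed, not prescribed):
#
#     readout = MLPReadout(256, 1)             # dims: 256 -> 128 -> 64 -> 1
#     y = readout(torch.randn(32, 256))        # shape: [32, 1]
#
# Note: MLPReadout is defined here but not referenced elsewhere in this file.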


class RegularBlock(nn.Module):
    """
    Inputs: N x input_depth x m x m
    Takes the input through 2 parallel MLP routes, multiplies the results,
    and adds a skip-connection at the end.
    At the skip-connection, the dimension is reduced back to output_depth.
    """

    def __init__(self, depth_of_mlp, in_features, out_features,
                 residual=False):
        super().__init__()

        self.residual = residual

        self.mlp1 = MlpBlock(in_features, out_features, depth_of_mlp)
        self.mlp2 = MlpBlock(in_features, out_features, depth_of_mlp)

        self.skip = SkipConnection(in_features + out_features, out_features)

        if self.residual:
            self.res_x = nn.Linear(in_features, out_features)

    def forward(self, inputs):
        mlp1 = self.mlp1(inputs)
        mlp2 = self.mlp2(inputs)

        mult = torch.matmul(mlp1, mlp2)

        out = self.skip(in1=inputs, in2=mult)

        if self.residual:
            # change shapes from [1 x d x n x n] to [n x n x d]
            # for the Linear() layer
            inputs = inputs.permute(3, 2, 1, 0).squeeze()
            out = out.permute(3, 2, 1, 0).squeeze()

            residual_ = self.res_x(inputs)
            out = residual_ + out  # residual connection

            # return output to its original shape
            out = out.permute(2, 1, 0).unsqueeze(0)

        return out
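
# Shape walkthrough for a forward pass (residual=False); these follow
# directly from the layers defined above:
#     inputs              : [N, in_features, m, m]
#     mlp1, mlp2          : [N, out_features, m, m]  (stacks of 1x1 convs)
#     torch.matmul(...)   : [N, out_features, m, m]  (batched m x m matmul)
#     skip(inputs, mult)  : [N, out_features, m, m]  (concat + 1x1 conv)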


class MlpBlock(nn.Module):
    """
    Block of MLP layers (1x1 conv layers) with an activation function
    after each layer.
    """

    def __init__(self, in_features, out_features, depth_of_mlp,
                 activation_fn=nn.functional.relu):
        super().__init__()
        self.activation = activation_fn
        self.convs = nn.ModuleList()
        for i in range(depth_of_mlp):
            self.convs.append(nn.Conv2d(in_features, out_features,
                                        kernel_size=1, padding=0, bias=True))
            _init_weights(self.convs[-1])
            in_features = out_features

    def forward(self, inputs):
        out = inputs
        for conv_layer in self.convs:
            out = self.activation(conv_layer(out))

        return out


class SkipConnection(nn.Module):
    """
    Connects the two given inputs with concatenation
    :param in1: earlier input tensor of shape N x d1 x m x m
    :param in2: later input tensor of shape N x d2 x m x m
    :param in_features: d1+d2
    :param out_features: output num of features
    :return: Tensor of shape N x output_depth x m x m
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.conv = nn.Conv2d(in_features, out_features, kernel_size=1,
                              padding=0, bias=True)
        _init_weights(self.conv)

    def forward(self, in1, in2):
        # in1: N x d1 x m x m
        # in2: N x d2 x m x m
        out = torch.cat((in1, in2), dim=1)
        out = self.conv(out)
        return out


class FullyConnected(nn.Module):
    def __init__(self, in_features, out_features,
                 activation_fn=nn.functional.relu):
        super().__init__()

        self.fc = nn.Linear(in_features, out_features)
        _init_weights(self.fc)

        self.activation = activation_fn

    def forward(self, input):
        out = self.fc(input)
        if self.activation is not None:
            out = self.activation(out)

        return out


def diag_offdiag_maxpool(input):
    N = input.shape[-1]

    # max over the diagonal entries: B x S
    max_diag = torch.max(torch.diagonal(input, dim1=-2, dim2=-1), dim=2)[0]

    # with torch.no_grad():
    max_val = torch.max(max_diag)
    min_val = torch.max(-1 * input)
    val = torch.abs(torch.add(max_val, min_val))

    # mask the diagonal with a large offset so it cannot win the
    # off-diagonal max below
    eye = torch.eye(N, device=input.device)
    min_mat = torch.mul(val, eye).view(1, 1, N, N)

    # max over the off-diagonal entries: B x S
    max_offdiag = torch.max(torch.max(input - min_mat, dim=3)[0], dim=2)[0]

    return torch.cat((max_diag, max_offdiag), dim=1)  # output: B x 2S
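
# For example, an input of shape [B, S, N, N] yields a pooled output of
# shape [B, 2S]: the first S channels hold the per-channel maximum over
# diagonal entries, the last S the maximum over off-diagonal entries, e.g.
#
#     pooled = diag_offdiag_maxpool(torch.randn(2, 16, 5, 5))  # [2, 32]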


def _init_weights(layer):
    """
    Init weights of the layer
    :param layer:
    :return:
    """
    nn.init.xavier_uniform_(layer.weight)
    # nn.init.xavier_normal_(layer.weight)
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)


class LayerNorm(nn.Module):
    def __init__(self, d):
        super().__init__()
        # shape is 1 x 1 x d
        self.a = nn.Parameter(torch.ones(d).unsqueeze(0).unsqueeze(0))
        # shape is 1 x 1 x d
        self.b = nn.Parameter(torch.zeros(d).unsqueeze(0).unsqueeze(0))

    def forward(self, x):
        # x tensor of the shape n x n x d
        mean = x.mean(dim=(0, 1), keepdim=True)
        var = x.var(dim=(0, 1), keepdim=True, unbiased=False)
        # output shape is n x n x d
        x = self.a * (x - mean) / torch.sqrt(var + 1e-6) + self.b
        return x
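

if __name__ == "__main__":
    # Quick smoke test (a minimal sketch; the hyper-parameters, feature
    # dimension, and node count below are arbitrary assumptions, not values
    # tied to any particular dataset).
    model = ThreeWLGNNNet(gnn_type='3wlgnn', num_tasks=1,
                          three_wl_in_dim=28, n_layers=3,
                          depth_of_mlp=2, hidden_dim=128)
    x = torch.randn(1, 28 + 1, 10, 10)  # one dense graph with 10 nodes
    scores = model(x)
    print(scores.shape)  # expected: torch.Size([1, 1])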