Line too long (86 > 79 characters):
10 CODE adapted from https://github.com/hadarser/ProvablyPowerfulGraphNetworks_torch/Line too long (104 > 79 characters):
15 def __init__(self, gnn_type, num_tasks, three_wl_in_dim, n_layers=3, depth_of_mlp=2, hidden_dim=128,Line too long (87 > 79 characters):
23 original_features_num = self.in_dim_node + 1 # Number of features of the inputLine too long (107 > 79 characters):
29 mlp_block = RegularBlock(depth_of_mlp, last_layer_features, next_layer_features, self.residual)Line too long (114 > 79 characters):
35 # each block's output will be pooled (thus have 2*output_features), and pass through a fully connectedLine too long (83 > 79 characters):
36 fc = FullyConnected(2 * output_features, num_tasks, activation_fn=None)Line too long (102 > 79 characters):
48 # def __init__(self, gnn_type, num_tasks, feature_dim, n_layers=3, depth_of_mlp=2, hidden_dim=128,Line too long (80 > 79 characters):
57 # block_features = [hidden_dim] * n_layers # L here is the block numberLine too long (110 > 79 characters):
58 # original_features_num = self.in_dim_node + self.num_bond_type + 1 # Number of features of the inputLine too long (109 > 79 characters):
64 # mlp_block = RegularBlock(depth_of_mlp, last_layer_features, next_layer_features, self.residual)Line too long (116 > 79 characters):
70 # # each block's output will be pooled (thus have 2*output_features), and pass through a fully connectedLine too long (85 > 79 characters):
71 # fc = FullyConnected(2 * output_features, num_tasks, activation_fn=None)Line too long (113 > 79 characters):
88 list_FC_layers = [nn.Linear(input_dim // 2 ** l, input_dim // 2 ** (l + 1), bias=True) for l in range(L)]Line too long (84 > 79 characters):
89 list_FC_layers.append(nn.Linear(input_dim // 2 ** L, output_dim, bias=True))Line too long (108 > 79 characters):
105 Take the input through 2 parallel MLP routes, multiply the result, and add a skip-connection at the end.Line too long (80 > 79 characters):
109 def __init__(self, depth_of_mlp, in_features, out_features, residual=False):Line too long (97 > 79 characters):
132 inputs, out = inputs.permute(3, 2, 1, 0).squeeze(), out.permute(3, 2, 1, 0).squeeze()Line too long (98 > 79 characters):
148 def __init__(self, in_features, out_features, depth_of_mlp, activation_fn=nn.functional.relu):Line too long (104 > 79 characters):
153 self.convs.append(nn.Conv2d(in_features, out_features, kernel_size=1, padding=0, bias=True))Line too long (93 > 79 characters):
177 self.conv = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0, bias=True)Line too long (84 > 79 characters):
189 def __init__(self, in_features, out_features, activation_fn=nn.functional.relu):Line too long (82 > 79 characters):
208 max_diag = torch.max(torch.diagonal(input, dim1=-2, dim2=-1), dim=2)[0] # BxSLine too long (80 > 79 characters):
215 min_mat = torch.mul(val, torch.eye(N, device=input.device)).view(1, 1, N, N)Line too long (82 > 79 characters):
217 max_offdiag = torch.max(torch.max(input - min_mat, dim=3)[0], dim=2)[0] # BxSLine too long (92 > 79 characters):
237 self.a = nn.Parameter(torch.ones(d).unsqueeze(0).unsqueeze(0)) # shape is 1 x 1 x dLine too long (93 > 79 characters):
238 self.b = nn.Parameter(torch.zeros(d).unsqueeze(0).unsqueeze(0)) # shape is 1 x 1 x dLine too long (87 > 79 characters):
244 x = self.a * (x - mean) / torch.sqrt(var + 1e-6) + self.b # shape is n x n x d