⬅ models/gsn/graph_filters/models_misc.py source

1 import torch
2 import torch.nn as nn
3  
4  
def choose_activation(activation):
    """Return the activation module named by ``activation``.

    Args:
        activation: one of 'elu', 'relu', 'tanh' or 'identity'.

    Returns:
        A callable ``torch.nn.Module`` (``nn.ELU``, ``nn.ReLU``, ``nn.Tanh``
        or ``nn.Identity``).

    Raises:
        NotImplementedError: if ``activation`` is not one of the names above.
    """
    if activation == 'elu':
        return nn.ELU()
    elif activation == 'relu':
        return nn.ReLU()
    elif activation == 'tanh':
        return nn.Tanh()
    elif activation == 'identity':
        # nn.Identity instead of ``lambda x: x``: a lambda stored as a model
        # attribute makes the model unpicklable (e.g. ``torch.save(model)``).
        return nn.Identity()
    raise NotImplementedError(
        "Unsupported activation {!r}; expected one of "
        "'elu', 'relu', 'tanh', 'identity'.".format(activation))
  • W293 Blank line contains whitespace
16
  • W293 Blank line contains whitespace
17
class mlp(torch.nn.Module):
    """Fully connected feed-forward network (multi-layer perceptron).

    The network maps ``in_features`` through the hidden widths listed in
    ``d_k`` and finally to ``out_features``. Every hidden layer is
    ``Linear -> [BatchNorm1d] -> activation``. The last layer is a plain
    ``Linear`` with no normalisation or activation.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 d_k,
                 seed,
                 activation='elu',
                 batch_norm=False):
        """Build the layers.

        Args:
            in_features: width of the input features.
            out_features: width of the output features.
            d_k: sequence (list, tuple, ...) of hidden-layer widths. It may be
                empty, in which case the network is a single linear layer.
            seed: stored as ``self.seed``; not otherwise used in this class.
            activation: name passed to ``choose_activation``.
            batch_norm: if truthy, insert a ``BatchNorm1d`` after every hidden
                linear layer (never after the output layer).
        """
        super(mlp, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.d_k = d_k  # the caller's hidden widths, as passed in
        self.seed = seed
        self.activation_name = activation
        self.batch_norm = batch_norm

        # list(...) so that tuples and other sequences of widths work too;
        # ``d_k + [out_features]`` raised TypeError for a tuple.
        out_dims = list(d_k) + [out_features]
        in_dims = [in_features] + out_dims[:-1]

        # Linear layers are created in the same order as before, so random
        # initialisation under a fixed torch seed is unchanged (BatchNorm
        # initialisation does not draw random numbers).
        self.fc = nn.ModuleList(
            nn.Linear(d_in, d_out, bias=True)
            for d_in, d_out in zip(in_dims, out_dims))

        # One BatchNorm per hidden layer; empty when batch_norm is off.
        bn_dims = out_dims[:-1] if self.batch_norm else []
        self.bn = nn.ModuleList(nn.BatchNorm1d(d) for d in bn_dims)

        self.activation = choose_activation(activation)

    def forward(self, x):
        """Run ``x`` through the hidden layers and then the output layer.

        NOTE(review): with batch_norm on, ``BatchNorm1d`` normalises dim 1, so
        this presumably expects ``x`` shaped (batch, features) — confirm.
        """
        for i in range(len(self.fc) - 1):
            x = self.fc[i](x)
            if self.batch_norm:
                x = self.bn[i](x)
            x = self.activation(x)
        return self.fc[-1](x)