import torch
import torch.nn as nn

            5 def choose_activation(activation):
            
            
               
               
                  - 
                     
                        E225
                     
                     Missing whitespace around operator
 
               
               
6     if activation =='elu':
             
            7         return nn.ELU()
            
            
               
               
                  - 
                     
                        E225
                     
                     Missing whitespace around operator
 
               
               
8     elif activation =='relu':
             
            9         return nn.ReLU()
            
            10     elif activation == 'tanh':
            
            11         return nn.Tanh()
            
            12     elif activation == 'identity':
            
            13         return lambda x: x
            
            14     else:
            
            15         raise NotImplementedError
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
16         
             
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
17             
             
            18 class mlp(torch.nn.Module):
            
            19  
            
            20     def __init__(self,
            
            21                  in_features,
            
            22                  out_features,
            
            23                  d_k,
            
            24                  seed,
            
            25                  activation='elu',
            
            26                  batch_norm=False):
            
            27         super(mlp, self).__init__()
            
            
               
               
                  - 
                     
                        W293
                     
                     Blank line contains whitespace
 
               
               
28         
             
            29         self.in_features = in_features
            
            30         self.out_features = out_features
            
            31         self.d_k = d_k
            
            32         self.seed = seed
            
            33         self.activation_name = activation
            
            34         self.batch_norm = batch_norm
            
            35  
            
            36         self.fc = []
            
            37         self.bn = []
            
            38  
            
            39         d_in = [in_features]
            
            40         d_k = d_k + [out_features]
            
            41         for i in range(0, len(d_k)):
            
            42             self.fc.append(nn.Linear(d_in[i], d_k[i], bias=True))
            
            43             d_in = d_in + [d_k[i]]
            
            
               
               
                  - 
                     
                        E225
                     
                     Missing whitespace around operator
 
               
               
44             if self.batch_norm and i!=len(d_k)-1:
             
            45                 self.bn.append(nn.BatchNorm1d((d_k[i])))
            
            46  
            
            47         self.fc = nn.ModuleList(self.fc)
            
            48         self.bn = nn.ModuleList(self.bn)
            
            49         self.activation = choose_activation(activation)
            
            50  
            
            51  
            
            
               
               
                  - 
                     
                        E303
                     
                     Too many blank lines (2)
 
               
               
52     def forward(self, x):
             
            53         for i in range(0, len(self.fc)-1):
            
            54             if self.batch_norm:
            
            55                 x = self.activation(self.bn[i](self.fc[i](x)))
            
            56             else:
            
            57                 x = self.activation(self.fc[i](x))
            
            58         x = self.fc[-1](x)
            
            59         return x