⬅ models/gsn/graph_filters/MPNN_sparse.py source

1 import torch
2 import torch.nn as nn
3 from torch_geometric.utils import degree
4  
5 from models_misc import mlp
6  
  • E302 Expected 2 blank lines, found 1
class MPNN_sparse(nn.Module):
    """Message-passing layer that aggregates edge messages via a sparse tensor.

    Each edge produces a message (``message``). The messages are scattered
    into an ``(n_nodes, n_nodes, d)`` sparse COO tensor indexed by
    ``edge_index`` and summed over the sender axis (``propagate``). The
    aggregated message is then combined with the node's own features and
    passed through ``update_fn`` (``forward``).

    Args:
        d_in: Dimensionality of the input node features.
        d_degree: Dimensionality of the ``degrees`` features given to
            ``forward``; only used when ``degree_as_tag`` is true.
        degree_as_tag: If true, ``degrees`` are used as node features
            (appended to ``x`` or replacing it, see ``retain_features``).
        retain_features: With ``degree_as_tag``, keep ``x`` and append the
            degree features instead of replacing ``x`` with them.
        d_msg: Message size for ``msg_kind='general'``; defaults to
            ``d_in`` when ``None``. Unused for ``'gin'``.
        d_up: Passed to the update ``mlp`` right after its input size
            (presumably the output size -- confirm against ``mlp``).
        d_h, seed, activation_name, bn: Forwarded unchanged to ``mlp``.
        aggr: ``'add'`` or ``'mean'`` aggregation of incoming messages.
        msg_kind: ``'gin'`` -- messages are the sender features and the
            update input is ``(1 + eps) * x + aggregated``; ``'general'``
            -- messages are ``msg_fn([x_i, x_j])`` and the update input is
            ``[x, aggregated]``.
        eps: Initial value of the GIN self-weight.
        train_eps: If true, ``eps`` is a learnable parameter, otherwise a
            buffer.
        flow: ``'target_to_source'`` aggregates at ``edge_index[0]``; any
            other value (default ``'source_to_target'``) aggregates at
            ``edge_index[1]``.
        **kwargs: Ignored.

    Raises:
        NotImplementedError: If ``msg_kind`` is neither ``'gin'`` nor
            ``'general'``.
    """

    def __init__(self,
                 d_in,
                 d_degree,
                 degree_as_tag,
                 retain_features,
                 d_msg,
                 d_up,
                 d_h,
                 seed,
                 activation_name,
                 bn,
                 aggr='add',
                 msg_kind='general',
                 eps=0,
                 train_eps=False,
                 flow='source_to_target',
                 **kwargs):

        super(MPNN_sparse, self).__init__()

        d_msg = d_in if d_msg is None else d_msg

        self.flow = flow
        self.aggr = aggr
        self.msg_kind = msg_kind

        self.degree_as_tag = degree_as_tag
        self.retain_features = retain_features

        # forward() builds node features from the degrees in this case,
        # so the effective input width changes accordingly.
        if degree_as_tag:
            d_in = d_in + d_degree if retain_features else d_degree

        if msg_kind == 'gin':
            self.initial_eps = eps
            if train_eps:
                self.eps = torch.nn.Parameter(torch.Tensor([eps]))
            else:
                self.register_buffer('eps', torch.Tensor([eps]))
            # GIN messages are the raw sender features: no message MLP.
            self.msg_fn = None
            update_input_dim = d_in
        elif msg_kind == 'general':
            # The message MLP sees the concatenation [x_i, x_j].
            self.msg_fn = mlp(2 * d_in, d_msg, d_h, seed, activation_name, bn)
            # The update MLP sees the concatenation [x, aggregated message].
            update_input_dim = d_in + d_msg
        else:
            raise NotImplementedError(
                "Message kind {} is not currently supported.".format(
                    msg_kind))

        self.update_fn = mlp(
            update_input_dim, d_up, d_h, seed, activation_name, bn)

    def forward(self, x, edge_index, **kwargs):
        """Run one round of message passing followed by the node update.

        Args:
            x: Node features; a 1-D tensor is treated as one feature per
                node.
            edge_index: Integer tensor of shape ``(2, n_edges)``.
            **kwargs: Must contain ``degrees`` when ``degree_as_tag`` is
                true (a 1-D tensor is treated as one feature per node);
                otherwise unused.

        Returns:
            The output of ``update_fn``.
        """
        x = x.unsqueeze(-1) if x.dim() == 1 else x

        # Degree features are only needed when they become node features.
        if self.degree_as_tag:
            degrees = kwargs['degrees']
            if degrees.dim() == 1:
                degrees = degrees.unsqueeze(-1)
            if self.retain_features:
                x = torch.cat([x, degrees], -1)
            else:
                x = degrees

        aggregated = self.propagate(edge_index=edge_index, x=x)
        if self.msg_kind == 'gin':
            return self.update_fn((1 + self.eps) * x + aggregated)
        if self.msg_kind == 'general':
            return self.update_fn(torch.cat((x, aggregated), -1))
        raise NotImplementedError(
            "Message kind {} is not currently supported.".format(
                self.msg_kind))

    def propagate(self, edge_index, x):
        """Aggregate per-edge messages at each edge's receiving node.

        Args:
            edge_index: Integer tensor of shape ``(2, n_edges)``.
            x: Node features of shape ``(n_nodes, d)``.

        Returns:
            Dense tensor with one row per node; nodes that receive no
            message get a zero row.

        Raises:
            NotImplementedError: If ``self.aggr`` is neither ``'add'`` nor
                ``'mean'``.
        """
        # Row of edge_index that holds the receiving node i; the other row
        # holds the sending node j.
        select = 0 if self.flow == 'target_to_source' else 1
        # The sparse tensor is indexed by (edge_index[0], edge_index[1]);
        # summing over the sender's axis leaves one entry per receiver.
        aggr_dim = 1 - select
        n_nodes = x.shape[0]

        edge_index_i = edge_index[select, :]
        edge_index_j = edge_index[1 - select, :]
        x_i, x_j = x[edge_index_i, :], x[edge_index_j, :]

        msgs = self.message(x_i, x_j)
        # sparse_coo_tensor keeps the dtype/device of the messages, unlike
        # the legacy float32-only torch.sparse.FloatTensor constructor.
        msgs = torch.sparse_coo_tensor(
            edge_index, msgs, (n_nodes, n_nodes, msgs.shape[1]))

        if self.aggr == 'add':
            message = torch.sparse.sum(msgs, aggr_dim).to_dense()
        elif self.aggr == 'mean':
            message = torch.sparse.sum(msgs, aggr_dim).to_dense()
            # Number of messages each node receives; minlength gives every
            # node an entry (even trailing nodes with no incoming edge),
            # and the clamp avoids 0/0 for those nodes, whose sum is zero.
            counts = torch.bincount(edge_index_i, minlength=n_nodes)
            counts = counts.clamp(min=1).to(message.dtype)
            message = message / counts.unsqueeze(1)
        else:
            raise NotImplementedError(
                "Aggregation kind {} is not currently supported.".format(
                    self.aggr))

        return message

    def message(self, x_i, x_j):
        """Compute one message per edge from receiver/sender features.

        ``'gin'`` sends the sender features ``x_j`` unchanged;
        ``'general'`` applies ``msg_fn`` to the concatenation
        ``[x_i, x_j]``.

        Raises:
            NotImplementedError: If ``self.msg_kind`` is not supported.
        """
        if self.msg_kind == 'gin':
            return x_j
        if self.msg_kind == 'general':
            return self.msg_fn(torch.cat((x_i, x_j), -1))
        raise NotImplementedError(
            "Message kind {} is not currently supported.".format(
                self.msg_kind))

    def __repr__(self):
        return '{}(msg_fn = {}, update_fn = {})'.format(
            self.__class__.__name__, self.msg_fn, self.update_fn)