⬅ models/gsn/graph_filters/MPNN_edge_sparse.py source

1 import torch
2 import torch.nn as nn
  • F401 'torch_geometric.utils.is_undirected' imported but unused
3 from torch_geometric.utils import degree, is_undirected
4  
5 from models_misc import mlp
6 from utils_graph_learning import central_encoder
7  
  • E302 Expected 2 blank lines, found 1
class MPNN_edge_sparse(nn.Module):
    """Edge-conditioned message-passing layer with sparse-tensor aggregation.

    Per-edge messages are scattered into a sparse COO tensor of shape
    ``(n_nodes, n_nodes, d)`` indexed by ``edge_index`` and reduced along
    one node axis with ``torch.sparse.sum``.

    Two message kinds are supported:

    * ``'gin'``: the message from neighbour ``j`` to node ``i`` is
      ``[x_j, e_ij]``. Node ``i`` is updated with
      ``update_fn((1 + eps) * [x_i, e_ii] + AGG_j [x_j, e_ij])``, where
      ``e_ii`` are the dummy self-loop edge features returned by
      ``central_node_edge_encoder``.
    * ``'general'``: the message is ``msg_fn([x_i, x_j, e_ij])`` and the
      update is ``update_fn([x_i, AGG_j msg_ij])``.

    Args:
        d_in: width of the input node features.
        d_ef: width of the input edge features.
        d_degree: width of the degree features (used if ``degree_as_tag``).
        degree_as_tag: if True, the ``degrees`` passed to ``forward`` are
            used as node features.
        retain_features: with ``degree_as_tag``, concatenate the degrees to
            the node features instead of replacing them.
        d_msg: output width of ``msg_fn`` (``'general'`` only). ``None``
            means the raw ``d_in`` as passed in, i.e. before the degree
            adjustment.
        d_up: output width of ``update_fn``.
        d_h, seed, activation_name, bn: forwarded unchanged to the project
            ``mlp`` helper for both ``msg_fn`` and ``update_fn``.
        aggr: ``'add'`` or ``'mean'``. Any other value raises
            ``NotImplementedError`` on the first ``forward``.
        msg_kind: ``'gin'`` or ``'general'``.
        eps: initial GIN epsilon (``'gin'`` only).
        train_eps: make epsilon a learnable ``Parameter`` instead of a
            buffer (``'gin'`` only).
        flow: ``'target_to_source'`` aggregates at ``edge_index[0]``. Any
            other value (default ``'source_to_target'``) aggregates at
            ``edge_index[1]``.
        **kwargs: for ``'gin'``, must provide ``'edge_embedding'`` and
            ``'extend_dims'``, which are forwarded to ``central_encoder``.

    Raises:
        NotImplementedError: if ``msg_kind`` is not ``'gin'`` or
            ``'general'``.
    """

    def __init__(self,
                 d_in,
                 d_ef,
                 d_degree,
                 degree_as_tag,
                 retain_features,
                 d_msg,
                 d_up,
                 d_h,
                 seed,
                 activation_name,
                 bn,
                 aggr='add',
                 msg_kind='general',
                 eps=0,
                 train_eps=False,
                 flow='source_to_target',
                 **kwargs):

        super(MPNN_edge_sparse, self).__init__()

        # Fail fast. Previously an unknown kind surfaced as an
        # UnboundLocalError on ``update_input_dim`` further down.
        if msg_kind not in ('gin', 'general'):
            raise NotImplementedError(
                "Message kind {} is not currently supported.".format(
                    msg_kind))

        # The default message width is taken from the raw input width,
        # before the degree-tag adjustment below.
        d_msg = d_in if d_msg is None else d_msg

        self.flow = flow
        self.aggr = aggr
        self.msg_kind = msg_kind

        self.degree_as_tag = degree_as_tag
        self.retain_features = retain_features

        if degree_as_tag:
            # Degrees are appended to, or replace, the node features.
            d_in = d_in + d_degree if retain_features else d_degree

        if msg_kind == 'gin':
            # Produces the dummy edge features of each node's implicit self
            # loop. Its output width replaces d_ef for all edges.
            self.central_node_edge_encoder = central_encoder(
                kwargs['edge_embedding'], d_ef, extend=kwargs['extend_dims'])
            d_ef = self.central_node_edge_encoder.d_out

            self.initial_eps = eps
            if train_eps:
                self.eps = torch.nn.Parameter(torch.Tensor([eps]))
            else:
                self.register_buffer('eps', torch.Tensor([eps]))
            self.msg_fn = None  # GIN messages are raw [x_j, e_ij] concats.
            update_input_dim = d_in + d_ef
        else:  # 'general'
            self.msg_fn = mlp(2 * d_in + d_ef, d_msg, d_h, seed,
                              activation_name, bn)
            update_input_dim = d_in + d_msg

        self.update_fn = mlp(update_input_dim, d_up, d_h, seed,
                             activation_name, bn)

    def forward(self, x, edge_index, **kwargs):
        """Run one round of message passing and node update.

        Args:
            x: node features, ``(n_nodes,)`` or ``(n_nodes, d)``. A 1-D
                input is treated as a single feature column.
            edge_index: ``(2, n_edges)`` integer tensor of edge endpoints.
            **kwargs: must contain ``'edge_features'``, one row per column
                of ``edge_index`` (1-D is treated as a single column).
                Must also contain ``'degrees'`` when ``degree_as_tag`` is
                set (1-D is treated as a single column).

        Returns:
            The output of ``update_fn``, one row per node.
        """
        x = x.unsqueeze(-1) if x.dim() == 1 else x

        if self.degree_as_tag:
            # Only read when actually used. Callers not tagging degrees no
            # longer have to pass them.
            degrees = kwargs['degrees']
            if degrees.dim() == 1:
                degrees = degrees.unsqueeze(-1)
            x = torch.cat([x, degrees], -1) if self.retain_features \
                else degrees

        edge_features = kwargs['edge_features']
        if edge_features.dim() == 1:
            edge_features = edge_features.unsqueeze(-1)

        if self.msg_kind == 'gin':
            n_nodes = x.shape[0]
            # Returns the self-loop edge features and the (re-encoded)
            # per-edge features, in that order.
            edge_features_ii, edge_features = self.central_node_edge_encoder(
                edge_features, n_nodes)
            self_msg = torch.cat((x, edge_features_ii), -1)
            aggregated = self.propagate(edge_index=edge_index, x=x,
                                        edge_features=edge_features)
            return self.update_fn((1 + self.eps) * self_msg + aggregated)

        if self.msg_kind == 'general':
            aggregated = self.propagate(edge_index=edge_index, x=x,
                                        edge_features=edge_features)
            return self.update_fn(torch.cat((x, aggregated), -1))

        raise NotImplementedError(
            "Message kind {} is not currently supported.".format(
                self.msg_kind))

    def propagate(self, edge_index, x, edge_features):
        """Compute per-edge messages and aggregate them at receiving nodes.

        With ``flow == 'target_to_source'`` the receiver ``i`` is
        ``edge_index[0]``. Otherwise it is ``edge_index[1]``. The sender
        ``j`` is the other row.

        Returns:
            Dense ``(n_nodes, d)`` tensor of aggregated messages. Nodes
            without incoming edges get zeros.

        Raises:
            NotImplementedError: if ``self.aggr`` is not ``'add'`` or
                ``'mean'``.
        """
        select = 0 if self.flow == 'target_to_source' else 1
        aggr_dim = 1 - select  # reduce over the sender axis
        n_nodes = x.shape[0]

        edge_index_i = edge_index[select, :]
        edge_index_j = edge_index[1 - select, :]
        x_i, x_j = x[edge_index_i, :], x[edge_index_j, :]

        msgs = self.message(x_i, x_j, edge_features)
        # Hybrid COO tensor: two sparse node dims plus one dense feature
        # dim. Unlike the legacy torch.sparse.FloatTensor constructor, this
        # keeps the values' device and dtype.
        msgs = torch.sparse_coo_tensor(
            edge_index, msgs, (n_nodes, n_nodes, msgs.shape[1]))

        if self.aggr == 'add':
            return torch.sparse.sum(msgs, aggr_dim).to_dense()

        if self.aggr == 'mean':
            message = torch.sparse.sum(msgs, aggr_dim).to_dense()
            # Incoming-edge count per receiver, sized to n_nodes even when
            # trailing nodes receive no edges. Zero counts become 1 so that
            # isolated nodes keep an all-zero aggregate instead of NaN.
            counts = torch.bincount(edge_index_i, minlength=n_nodes)
            counts = counts.clamp(min=1).to(message.dtype)
            return message / counts.unsqueeze(1)

        raise NotImplementedError(
            "Aggregation kind {} is not currently supported.".format(
                self.aggr))

    def message(self, x_i, x_j, edge_features):
        """Build one message per edge from receiver/sender features.

        Raises:
            NotImplementedError: for an unsupported ``msg_kind``.
        """
        if self.msg_kind == 'gin':
            return torch.cat((x_j, edge_features), -1)
        if self.msg_kind == 'general':
            return self.msg_fn(torch.cat((x_i, x_j, edge_features), -1))
        raise NotImplementedError(
            "Message kind {} is not currently supported.".format(
                self.msg_kind))

    def __repr__(self):
        return '{}(msg_fn = {}, update_fn = {})'.format(
            self.__class__.__name__, self.msg_fn, self.update_fn)