models/gsn/graph_filters/GSN_edge_sparse.py

import torch
import torch.nn as nn
from torch_geometric.utils import degree

from models_misc import mlp
from utils_graph_learning import central_encoder

class GSN_edge_sparse(nn.Module):

    def __init__(self,
                 d_in,
                 d_ef,
                 d_id,
                 d_degree,
                 degree_as_tag,
                 retain_features,
                 id_scope,
                 d_msg,
                 d_up,
                 d_h,
                 seed,
                 activation_name,
                 bn,
                 aggr='add',
                 msg_kind='general',
                 eps=0,
                 train_eps=False,
                 flow='source_to_target',
                 **kwargs):

        super(GSN_edge_sparse, self).__init__()

        d_msg = d_in if d_msg is None else d_msg

        self.flow = flow
        self.aggr = aggr
        self.msg_kind = msg_kind
        self.id_scope = id_scope

        self.degree_as_tag = degree_as_tag
        self.retain_features = retain_features

        if degree_as_tag:
            d_in = d_in + d_degree if retain_features else d_degree

        # input dimensions of the message/update MLPs
        if msg_kind == 'gin':

            # dummy variable for self-loop edge features
            self.central_node_edge_encoder = central_encoder(
                kwargs['edge_embedding'], d_ef, extend=kwargs['extend_dims'])
            d_ef = self.central_node_edge_encoder.d_out

            if self.id_scope == 'local':
                # dummy variable for self-loop edge counts
                self.central_node_id_encoder = central_encoder(
                    kwargs['id_embedding'], d_id,
                    extend=kwargs['extend_dims'])
                d_id = self.central_node_id_encoder.d_out

            msg_input_dim = None
            self.initial_eps = eps
            if train_eps:
                self.eps = torch.nn.Parameter(torch.Tensor([eps]))
            else:
                self.register_buffer('eps', torch.Tensor([eps]))
            self.eps.data.fill_(self.initial_eps)

            self.msg_fn = None
            update_input_dim = d_in + d_id + d_ef

        elif msg_kind == 'general':
            msg_input_dim = (2 * d_in + d_id + d_ef if id_scope == 'local'
                             else 2 * (d_in + d_id) + d_ef)
            # message MLP
            self.msg_fn = mlp(msg_input_dim, d_msg, d_h, seed,
                              activation_name, bn)
            update_input_dim = d_in + d_msg

        else:
            raise NotImplementedError(
                'msg kind {} is not currently supported.'.format(msg_kind))

        # update MLP
        self.update_fn = mlp(update_input_dim, d_up, d_h, seed,
                             activation_name, bn)

        return

    def forward(self, x, edge_index, **kwargs):

        x = x.unsqueeze(-1) if x.dim() == 1 else x
        identifiers, degrees = kwargs['identifiers'], kwargs['degrees']
        degrees = degrees.unsqueeze(-1) if degrees.dim() == 1 else degrees
        if self.degree_as_tag:
            x = (torch.cat([x, degrees], -1) if self.retain_features
                 else degrees)

        edge_features = kwargs['edge_features']
        edge_features = (edge_features.unsqueeze(-1)
                         if edge_features.dim() == 1 else edge_features)

        n_nodes = x.shape[0]

        if self.msg_kind == 'gin':

            # encode dummy self-loop edge features (and, for local ids,
            # dummy self-loop counts) so the central node enters the
            # GIN-style, epsilon-weighted update
            edge_features_ii, edge_features = self.central_node_edge_encoder(
                edge_features, n_nodes)

            if self.id_scope == 'global':
                identifiers_ii = identifiers
            else:
                identifiers_ii, identifiers = self.central_node_id_encoder(
                    identifiers, n_nodes)

            self_msg = torch.cat((x, identifiers_ii, edge_features_ii), -1)

            out = self.update_fn(
                (1 + self.eps) * self_msg
                + self.propagate(edge_index=edge_index,
                                 x=x,
                                 identifiers=identifiers,
                                 edge_features=edge_features))

        elif self.msg_kind == 'general':
            out = self.update_fn(
                torch.cat((x,
                           self.propagate(edge_index=edge_index,
                                          x=x,
                                          identifiers=identifiers,
                                          edge_features=edge_features)),
                          -1))

        return out

    def propagate(self, edge_index, x, identifiers, edge_features):

        select = 0 if self.flow == 'target_to_source' else 1
        aggr_dim = 1 - select

        edge_index_i = edge_index[select, :]
        edge_index_j = edge_index[1 - select, :]
        x_i, x_j = x[edge_index_i, :], x[edge_index_j, :]

        if self.id_scope == 'local':
            identifiers_ij = identifiers
            identifiers_i, identifiers_j = None, None
        else:
            identifiers_ij = None
            identifiers_i = identifiers[edge_index_i, :]
            identifiers_j = identifiers[edge_index_j, :]

        n_nodes = x.shape[0]
        msgs = self.message(x_i, x_j, identifiers_i, identifiers_j,
                            identifiers_ij, edge_features)
        # scatter per-edge messages into a sparse (n_nodes, n_nodes, d)
        # tensor and aggregate over the neighbour dimension
        msgs = torch.sparse.FloatTensor(
            edge_index, msgs, torch.Size([n_nodes, n_nodes, msgs.shape[1]]))

        if self.aggr == 'add':
            message = torch.sparse.sum(msgs, aggr_dim).to_dense()

        elif self.aggr == 'mean':
            degrees = degree(edge_index[select])
            degrees[degrees == 0.0] = 1.0
            message = torch.sparse.sum(msgs, aggr_dim).to_dense()
            message = message / degrees.unsqueeze(1)

        else:
            raise NotImplementedError(
                "Aggregation kind {} is not currently supported."
                .format(self.aggr))

        return message

    def message(self, x_i, x_j, identifiers_i, identifiers_j,
                identifiers_ij, edge_features):

        if self.msg_kind == 'gin':
            if self.id_scope == 'global':
                msg_j = torch.cat((x_j, identifiers_j, edge_features), -1)
            else:
                msg_j = torch.cat((x_j, identifiers_ij, edge_features), -1)

        elif self.msg_kind == 'general':
            if self.id_scope == 'global':
                msg_j = torch.cat((x_i, x_j, identifiers_i, identifiers_j,
                                   edge_features), -1)
            else:
                msg_j = torch.cat((x_i, x_j, identifiers_ij,
                                   edge_features), -1)
            msg_j = self.msg_fn(msg_j)

        else:
            raise NotImplementedError(
                "Message kind {} is not currently supported."
                .format(self.msg_kind))

        return msg_j

    def __repr__(self):
        return '{}(msg_fn = {}, update_fn = {})'.format(
            self.__class__.__name__, self.msg_fn, self.update_fn)
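

# Minimal usage sketch (illustrative only, not part of the original module):
# builds the layer for a toy 3-node path graph with random inputs. The
# dimensions, the 'relu' activation name, the integer d_h and the keyword
# values below are assumptions made for demonstration; they depend on the
# repo's mlp helper and on how the training scripts construct this layer.
if __name__ == '__main__':
    layer = GSN_edge_sparse(d_in=8, d_ef=4, d_id=6, d_degree=1,
                            degree_as_tag=False, retain_features=True,
                            id_scope='global', d_msg=16, d_up=16, d_h=16,
                            seed=0, activation_name='relu', bn=False,
                            msg_kind='general')
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])  # both directions of each edge
    out = layer(torch.randn(3, 8), edge_index,
                identifiers=torch.randn(3, 6),
                degrees=degree(edge_index[1]),
                edge_features=torch.randn(4, 4))
    print(out.shape)  # expected: torch.Size([3, 16]), i.e. (n_nodes, d_up)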