1 import torch
2 import torch.nn as nn
# NOTE(review, flake8 F401): 'is_undirected' is imported below but never
# used in this file — confirm it is not part of the module's public
# re-export surface before removing it.
3 from torch_geometric.utils import degree, is_undirected
4
5 from .models_misc import mlp
6 from .utils_graph_learning import central_encoder
7
8
class GSN_sparse(nn.Module):
    """Sparse message-passing layer for Graph Substructure Networks (GSN).

    Standard message passing is augmented with precomputed substructure
    identifiers, concatenated either per node (``id_scope='global'``) or
    per edge (``id_scope='local'``). Two variants are implemented (see the
    paper's supplementary material):

    * ``msg_kind='gin'``     -- extended GIN update:
      ``update_fn((1 + eps) * [x_i, id_i] + sum_j [x_j, id_j])``
    * ``msg_kind='general'`` -- learned message function:
      ``update_fn([x_i, aggr_j msg_fn(...)])``
    """

    def __init__(self,
                 d_in,
                 d_id,
                 d_degree,
                 degree_as_tag,
                 retain_features,
                 id_scope,
                 d_msg,
                 d_up,
                 d_h,
                 seed,
                 activation_name,
                 bn,
                 aggr='add',
                 msg_kind='general',
                 eps=0,
                 train_eps=False,
                 flow='source_to_target',
                 **kwargs):
        """Build the message/update networks for the chosen GSN variant.

        Raises:
            NotImplementedError: if ``msg_kind`` is neither ``'gin'`` nor
                ``'general'``.
        """
        super(GSN_sparse, self).__init__()

        # Default the message width to the input width.
        d_msg = d_in if d_msg is None else d_msg

        self.flow = flow
        self.aggr = aggr
        self.msg_kind = msg_kind
        self.id_scope = id_scope

        self.degree_as_tag = degree_as_tag
        self.retain_features = retain_features

        # Optionally append (or substitute) the node degree as a feature;
        # the effective input width grows accordingly.
        if degree_as_tag:
            d_in = d_in + d_degree if retain_features else d_degree

        if self.msg_kind == 'gin':

            if self.id_scope == 'local':
                # Edge-level identifiers have no value for the central
                # node itself, so a dummy encoding is learned for it.
                self.central_node_id_encoder = central_encoder(
                    kwargs['id_embedding'], d_id,
                    extend=kwargs['extend_dims'])
                d_id = self.central_node_id_encoder.d_out

            msg_input_dim = None
            update_input_dim = d_in + d_id

            # GIN-style (1 + eps) self-loop weighting; eps is either a
            # trainable parameter or a fixed buffer.
            self.initial_eps = eps
            if train_eps:
                self.eps = torch.nn.Parameter(torch.Tensor([eps]))
            else:
                self.register_buffer('eps', torch.Tensor([eps]))
            self.eps.data.fill_(self.initial_eps)

            self.msg_fn = None

        elif self.msg_kind == 'general':

            # Local identifiers are shared per edge (counted once);
            # global identifiers appear once per endpoint.
            msg_input_dim = (2 * d_in + d_id if id_scope == 'local'
                             else 2 * (d_in + d_id))

            self.msg_fn = mlp(
                msg_input_dim,
                d_msg,
                d_h,
                seed,
                activation_name,
                bn)

            update_input_dim = d_in + d_msg

        else:
            raise NotImplementedError(
                'msg kind {} is not currently supported.'.format(msg_kind))

        self.update_fn = mlp(
            update_input_dim,
            d_up,
            d_h,
            seed,
            activation_name,
            bn)

    def forward(self, x, edge_index, **kwargs):
        """Run one round of GSN message passing.

        Args:
            x: node features, ``(n_nodes,)`` or ``(n_nodes, d_in)``.
            edge_index: ``(2, n_edges)`` COO connectivity.
            **kwargs: must contain ``'identifiers'`` and ``'degrees'``.

        Returns:
            Updated node representations of width ``d_up``.
        """
        # Promote 1-D features to column vectors.
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        identifiers, degrees = kwargs['identifiers'], kwargs['degrees']
        degrees = degrees.unsqueeze(-1) if degrees.dim() == 1 else degrees
        if self.degree_as_tag:
            x = (torch.cat([x, degrees], -1) if self.retain_features
                 else degrees)

        n_nodes = x.shape[0]
        if self.msg_kind == 'gin':
            if self.id_scope == 'global':
                identifiers_ii = identifiers
            else:
                # Encode a dummy identifier for the central node and
                # re-embed the edge-level identifiers alongside it.
                identifiers_ii, identifiers = self.central_node_id_encoder(
                    identifiers, n_nodes)
            self_msg = torch.cat((x, identifiers_ii), -1)

            out = self.update_fn(
                (1 + self.eps) * self_msg
                + self.propagate(edge_index=edge_index, x=x,
                                 identifiers=identifiers))

        elif self.msg_kind == 'general':
            out = self.update_fn(
                torch.cat((x,
                           self.propagate(edge_index=edge_index, x=x,
                                          identifiers=identifiers)), -1))

        else:
            raise NotImplementedError(
                "Message kind {} is not currently supported.".format(
                    self.msg_kind))

        return out

    def propagate(self, edge_index, x, identifiers):
        """Compute and aggregate per-edge messages into per-node vectors.

        Raises:
            NotImplementedError: if ``self.aggr`` is neither ``'add'``
                nor ``'mean'``.
        """
        # With the default 'source_to_target' flow, messages travel along
        # edge_index[0] -> edge_index[1] and are aggregated at the target.
        select = 0 if self.flow == 'target_to_source' else 1
        aggr_dim = 1 - select

        edge_index_i = edge_index[select, :]
        edge_index_j = edge_index[1 - select, :]
        x_i, x_j = x[edge_index_i, :], x[edge_index_j, :]

        if self.id_scope == 'local':
            identifiers_ij = identifiers
            identifiers_i, identifiers_j = None, None
        else:
            identifiers_ij = None
            identifiers_i = identifiers[edge_index_i, :]
            identifiers_j = identifiers[edge_index_j, :]

        n_nodes = x.shape[0]
        msgs = self.message(x_i, x_j, identifiers_i, identifiers_j,
                            identifiers_ij)
        # Scatter edge messages into a sparse (n, n, d) tensor so the
        # neighbour dimension can be reduced with torch.sparse.sum.
        # (torch.sparse_coo_tensor replaces the deprecated
        # torch.sparse.FloatTensor constructor.)
        msgs = torch.sparse_coo_tensor(
            edge_index, msgs,
            torch.Size([n_nodes, n_nodes, msgs.shape[1]]))

        if self.aggr == 'add':
            message = torch.sparse.sum(msgs, aggr_dim).to_dense()

        elif self.aggr == 'mean':
            # Per-receiver edge counts via bincount(minlength=n_nodes):
            # unlike degree(edge_index[select]), this always has length
            # n_nodes even when trailing nodes are isolated, so the
            # division below cannot shape-mismatch.
            degrees = torch.bincount(
                edge_index[select], minlength=n_nodes).to(x.dtype)
            degrees[degrees == 0.0] = 1.0  # avoid 0/0 on isolated nodes
            # BUG FIX: was `torch.sparse.sum(msgs, aggr_index)` with the
            # undefined name `aggr_index` (NameError on the mean path).
            message = torch.sparse.sum(msgs, aggr_dim).to_dense()
            message = message / degrees.unsqueeze(1)

        else:
            raise NotImplementedError(
                "Aggregation kind {} is not currently supported.".format(
                    self.aggr))

        return message

    def message(self, x_i, x_j, identifiers_i, identifiers_j,
                identifiers_ij):
        """Build the per-edge message for the configured GSN variant.

        Raises:
            NotImplementedError: if ``self.msg_kind`` is unsupported.
        """
        if self.msg_kind == 'gin':
            # GIN variant: plain concatenation, no learned message net.
            if self.id_scope == 'global':
                msg_j = torch.cat((x_j, identifiers_j), -1)
            else:
                msg_j = torch.cat((x_j, identifiers_ij), -1)

        elif self.msg_kind == 'general':
            if self.id_scope == 'global':
                msg_j = torch.cat((x_i, x_j, identifiers_i,
                                   identifiers_j), -1)
            else:
                msg_j = torch.cat((x_i, x_j, identifiers_ij), -1)
            msg_j = self.msg_fn(msg_j)

        else:
            raise NotImplementedError(
                "Message kind {} is not currently supported.".format(
                    self.msg_kind))

        return msg_j

    def __repr__(self):
        return '{}(msg_fn = {}, update_fn = {})'.format(
            self.__class__.__name__, self.msg_fn, self.update_fn)