-
F401
'copy' imported but unused
1 import copy
2
3 import torch
4 from torch_geometric.nn import MessagePassing
5 import torch.nn.functional as F
6 from ogb.graphproppred.mol_encoder import BondEncoder
7 from torch_geometric.utils import degree
8
9
10 from typing import Optional
11 from torch_geometric.typing import OptTensor
12 from torch.nn import Parameter
-
F401
'torch_geometric.nn.inits.zeros' imported but unused
13 from torch_geometric.nn.inits import zeros
14 from torch_geometric.utils import get_laplacian
-
E501
Line too long (89 > 79 characters)
15 from torch_geometric.utils import remove_self_loops, add_self_loops, segregate_self_loops
16 from torch_sparse import coalesce
17
-
F401
'pdb' imported but unused
18 import pdb
19
20
class GINConvNew(MessagePassing):
    """Graph Isomorphism Network (GIN) layer with edge-feature embeddings.

    Each edge sends the message ``relu(x_j + edge_embedding)``. Messages are
    sum-aggregated, added to ``(1 + eps) * x``, and the result is fed
    through a two-layer MLP (Linear -> BatchNorm -> ReLU -> Linear).
    """

    def __init__(self, emb_dim, dataset_group, edge_dim=7):
        """
        Args:
            emb_dim (int): node embedding dimensionality.
            dataset_group (str): ``'mol'`` selects an OGB ``BondEncoder`` for
                the edge features; any other value selects a linear
                projection from ``edge_dim`` raw edge features.
            edge_dim (int): number of raw edge features used by the linear
                edge encoder when ``dataset_group != 'mol'`` (default 7).
        """
        super().__init__(aggr="add")

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2 * emb_dim),
            torch.nn.BatchNorm1d(2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, emb_dim),
        )
        # Learnable weight on the node's own features; initialised to 0.
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        if dataset_group == 'mol':
            self.edge_encoder = BondEncoder(emb_dim=emb_dim)
        else:
            self.edge_encoder = torch.nn.Linear(edge_dim, emb_dim)

    def forward(self, x, edge_index, edge_attr):
        """Return ``mlp((1 + eps) * x + sum_j relu(x_j + e_ij))``."""
        edge_embedding = self.edge_encoder(edge_attr)
        self_term = (1 + self.eps) * x
        neighbour_term = self.propagate(edge_index, x=x,
                                        edge_attr=edge_embedding)
        return self.mlp(self_term + neighbour_term)

    def message(self, x_j, edge_attr):
        """Combine a neighbour's features with the embedded edge feature."""
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        """Pass the aggregated messages through unchanged."""
        return aggr_out
50
51
class GCNConvNew(MessagePassing):
    """GCN-style layer with edge-feature embeddings and a learned root term.

    After a linear projection of ``x`` the output is::

        sum_j norm_ij * relu(x_j + e_ij) + relu(x_i + root_emb) / deg_i

    where ``deg = degree(edge_index[0]) + 1`` and
    ``norm_ij = deg_i^-1/2 * deg_j^-1/2``.
    """

    def __init__(self, emb_dim, dataset_group, edge_dim=7):
        """
        Args:
            emb_dim (int): node embedding dimensionality.
            dataset_group (str): ``'mol'`` selects an OGB ``BondEncoder`` for
                the edge features; any other value selects a linear
                projection from ``edge_dim`` raw edge features.
            edge_dim (int): number of raw edge features used by the linear
                edge encoder when ``dataset_group != 'mol'`` (default 7).
        """
        super().__init__(aggr='add')

        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        # A single learned vector, broadcast-added to every node's features
        # in the self ("root") term of forward().
        self.root_emb = torch.nn.Embedding(1, emb_dim)

        if dataset_group == 'mol':
            self.edge_encoder = BondEncoder(emb_dim=emb_dim)
        else:
            self.edge_encoder = torch.nn.Linear(edge_dim, emb_dim)

    def forward(self, x, edge_index, edge_attr):
        """Return the normalised neighbour sum plus the root term."""
        x = self.linear(x)
        edge_embedding = self.edge_encoder(edge_attr)

        row, col = edge_index

        # Degree is counted over edge_index[0]; the +1 accounts for the
        # node itself, which is handled by the root term below.
        deg = degree(row, x.size(0), dtype=x.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        # Defensive only: deg >= 1 here, so no entry can be inf today.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        neighbour_term = self.propagate(edge_index, x=x,
                                        edge_attr=edge_embedding, norm=norm)
        root_term = F.relu(x + self.root_emb.weight) * 1. / deg.view(-1, 1)
        return neighbour_term + root_term

    def message(self, x_j, edge_attr, norm):
        """Scale ``relu(x_j + edge_attr)`` by the symmetric edge norm."""
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        """Pass the aggregated messages through unchanged."""
        return aggr_out
85
86
class ChebConvNew(MessagePassing):
    """Chebyshev-polynomial graph convolution with edge-feature embeddings.

    The output is ``sum_k lins[k](T_k) (+ bias)`` with ``T_0 = x``,
    ``T_1 = S(x)`` and ``T_k = 2 * S(T_{k-1}) - T_{k-2}``. ``S`` applies the
    rescaled Laplacian ``2 L / lambda_max - I``:

    * its off-diagonal entries weight messages ``relu(h_j + edge_embedding)``;
    * its diagonal entries weight ``relu(h + root_emb)``.
    """

    def __init__(self, emb_dim: int, K: int, dataset_group: str,
                 normalization: Optional[str] = 'sym', bias: bool = True,
                 edge_dim: int = 7, **kwargs):
        """
        Args:
            emb_dim: node embedding dimensionality (input and output).
            K: number of Chebyshev terms; must be positive.
            dataset_group: ``'mol'`` selects an OGB ``BondEncoder`` for the
                edge features; any other value selects a linear projection
                from ``edge_dim`` raw edge features.
            normalization: Laplacian normalization passed to
                ``get_laplacian``; one of ``None``, ``'sym'``, ``'rw'``.
            bias: whether to add a learnable bias (initialised to zero).
            edge_dim: number of raw edge features used by the linear edge
                encoder when ``dataset_group != 'mol'`` (default 7).
            **kwargs: forwarded to ``MessagePassing`` (``aggr`` defaults to
                ``'add'``).
        """
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        assert K > 0
        assert normalization in [None, 'sym', 'rw'], 'Invalid normalization'

        # Learned vector added to every node's features in the diagonal
        # (self-loop) term of the Laplacian application.
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        if dataset_group == 'mol':
            self.edge_encoder = BondEncoder(emb_dim=emb_dim)
        else:
            self.edge_encoder = torch.nn.Linear(edge_dim, emb_dim)

        self.emb_dim = emb_dim
        self.normalization = normalization
        self.lins = torch.nn.ModuleList([
            torch.nn.Linear(emb_dim, emb_dim, bias=False) for _ in range(K)
        ])

        if bias:
            # torch.Tensor(n) returns uninitialised memory and nothing else
            # initialises this parameter, so start from zeros explicitly.
            self.bias = Parameter(torch.zeros(emb_dim))
        else:
            self.register_parameter('bias', None)

    def __norm__(self, edge_index, num_nodes: Optional[int],
                 edge_weight: OptTensor, normalization: Optional[str],
                 lambda_max, dtype: Optional[torch.dtype] = None,
                 batch: OptTensor = None):
        """Return the rescaled Laplacian ``2 L / lambda_max - I`` as edges.

        The result is not coalesced. Diagonal (self-loop) entries may occur
        more than once; ``forward`` sums them with ``coalesce``.
        """
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, edge_weight = get_laplacian(edge_index, edge_weight,
                                                normalization, dtype,
                                                num_nodes)

        if batch is not None and lambda_max.numel() > 1:
            # One lambda_max per graph: look up each edge's graph through
            # the batch assignment of its edge_index[0] node.
            lambda_max = lambda_max[batch[edge_index[0]]]

        edge_weight = (2.0 * edge_weight) / lambda_max
        edge_weight.masked_fill_(edge_weight == float('inf'), 0)

        # Subtract the identity: one -1 self-loop per node.
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 fill_value=-1.,
                                                 num_nodes=num_nodes)
        assert edge_weight is not None
        return edge_index, edge_weight

    def forward(self, x, edge_index, edge_attr: OptTensor = None,
                batch: OptTensor = None, lambda_max: OptTensor = None):
        """Apply the K-term Chebyshev convolution.

        Args:
            x: node features.
            edge_index: graph connectivity.
            edge_attr: raw edge features for ``self.edge_encoder``. It is
                typed optional but always encoded, so ``None`` is not
                actually supported (NOTE(review): confirm with callers).
            batch: node-to-graph assignment. It is only used when
                ``lambda_max`` holds more than one value.
            lambda_max: largest Laplacian eigenvalue(s). Required unless
                ``normalization == 'sym'``, in which case it defaults to
                2.0. Non-tensor values are converted using ``x``'s dtype
                and device.

        Returns:
            Node features of width ``emb_dim``.

        Raises:
            ValueError: if ``lambda_max`` is missing and the normalization
                is not ``'sym'``.
        """
        edge_embedding = self.edge_encoder(edge_attr)

        if self.normalization != 'sym' and lambda_max is None:
            raise ValueError('You need to pass `lambda_max` to '
                             '`forward()` in case the normalization is '
                             'non-symmetric.')

        if lambda_max is None:
            lambda_max = torch.tensor(2.0, dtype=x.dtype, device=x.device)
        if not isinstance(lambda_max, torch.Tensor):
            lambda_max = torch.tensor(lambda_max, dtype=x.dtype,
                                      device=x.device)
        assert lambda_max is not None

        edge_index, norm = self.__norm__(edge_index, x.size(self.node_dim),
                                         None, self.normalization,
                                         lambda_max, dtype=x.dtype,
                                         batch=batch)

        # Handle the diagonal separately. Duplicate self-loop weights are
        # summed so that each node ends up with one diagonal weight.
        edge_index, norm, loop_edge_index, loop_norm = segregate_self_loops(
            edge_index, norm)
        loop_edge_index, loop_norm = coalesce(loop_edge_index, loop_norm,
                                              m=x.shape[0], n=x.shape[0])
        loop_norm = loop_norm.view(-1, 1)
        root_weight = self.root_emb.weight

        def apply_scaled_laplacian(h):
            # Off-diagonal entries go through message passing; diagonal
            # entries scale the root term relu(h + root_emb) directly.
            neighbours = self.propagate(edge_index, x=h,
                                        edge_attr=edge_embedding,
                                        norm=norm, size=None)
            return neighbours + loop_norm * F.relu(h + root_weight)

        Tx_0 = x
        out = self.lins[0](Tx_0)

        if len(self.lins) > 1:
            Tx_1 = apply_scaled_laplacian(x)
            out = out + self.lins[1](Tx_1)

        # Chebyshev recurrence: T_k = 2 * S(T_{k-1}) - T_{k-2}.
        for lin in self.lins[2:]:
            Tx_2 = 2. * apply_scaled_laplacian(Tx_1) - Tx_0
            out = out + lin(Tx_2)
            Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out += self.bias

        return out

    def message(self, x_j, edge_attr, norm):
        """Scale ``relu(x_j + edge_attr)`` by the edge's Laplacian weight."""
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)