models/conv.py

import torch
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder
from torch_geometric.utils import degree


from typing import Optional
from torch_geometric.typing import OptTensor
from torch.nn import Parameter
from torch_geometric.utils import get_laplacian
from torch_geometric.utils import (remove_self_loops, add_self_loops,
                                   segregate_self_loops)
from torch_sparse import coalesce


class GINConvNew(MessagePassing):
    def __init__(self, emb_dim, dataset_group):
        '''
        emb_dim (int): node embedding dimensionality
        '''
        super(GINConvNew, self).__init__(aggr="add")

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2 * emb_dim),
            torch.nn.BatchNorm1d(2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, emb_dim))
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        if dataset_group == 'mol':
            self.edge_encoder = BondEncoder(emb_dim=emb_dim)
        else:
            self.edge_encoder = torch.nn.Linear(7, emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.edge_encoder(edge_attr)

        # GIN update: (1 + eps) * x plus the sum-aggregated messages,
        # pushed through the two-layer MLP.
        out = self.mlp((1 + self.eps) * x
                       + self.propagate(edge_index, x=x,
                                        edge_attr=edge_embedding))

        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
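

# A minimal usage sketch for GINConvNew on a toy graph. It assumes a
# non-'mol' dataset_group, i.e. raw 7-dimensional edge features as expected
# by the Linear(7, emb_dim) encoder; the helper below is purely
# illustrative and is not called anywhere in this module.
def _example_gin_conv(emb_dim=64):
    conv = GINConvNew(emb_dim, dataset_group='other')
    x = torch.randn(4, emb_dim)                    # 4 nodes
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 0, 3, 2]])      # 4 directed edges
    edge_attr = torch.randn(4, 7)                  # raw 7-dim edge features
    return conv(x, edge_index, edge_attr)          # -> [4, emb_dim]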


class GCNConvNew(MessagePassing):
    def __init__(self, emb_dim, dataset_group):
        super(GCNConvNew, self).__init__(aggr='add')

        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.root_emb = torch.nn.Embedding(1, emb_dim)

        if dataset_group == 'mol':
            self.edge_encoder = BondEncoder(emb_dim=emb_dim)
        else:
            self.edge_encoder = torch.nn.Linear(7, emb_dim)

    def forward(self, x, edge_index, edge_attr):
        x = self.linear(x)
        edge_embedding = self.edge_encoder(edge_attr)

        row, col = edge_index

        # Symmetric GCN normalization; the +1 accounts for the implicit
        # self-connection handled by the root embedding term below.
        # edge_weight = torch.ones((edge_index.size(1), ),
        #                          device=edge_index.device)
        deg = degree(row, x.size(0), dtype=x.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        out = self.propagate(edge_index, x=x, edge_attr=edge_embedding,
                             norm=norm)
        return out + F.relu(x + self.root_emb.weight) * 1. / deg.view(-1, 1)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
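

# A minimal usage sketch for GCNConvNew on the same kind of toy graph,
# again assuming a non-'mol' dataset_group with 7-dimensional raw edge
# features; the helper name is purely illustrative.
def _example_gcn_conv(emb_dim=64):
    conv = GCNConvNew(emb_dim, dataset_group='other')
    x = torch.randn(4, emb_dim)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 0, 3, 2]])
    edge_attr = torch.randn(4, 7)
    return conv(x, edge_index, edge_attr)          # -> [4, emb_dim]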


class ChebConvNew(MessagePassing):
    def __init__(self, emb_dim: int, K: int, dataset_group: str,
                 normalization: Optional[str] = 'sym', bias: bool = True,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(ChebConvNew, self).__init__(**kwargs)

        assert K > 0
        assert normalization in [None, 'sym', 'rw'], 'Invalid normalization'

        self.root_emb = torch.nn.Embedding(1, emb_dim)
        if dataset_group == 'mol':
            self.edge_encoder = BondEncoder(emb_dim=emb_dim)
        else:
            self.edge_encoder = torch.nn.Linear(7, emb_dim)

        self.emb_dim = emb_dim
        self.normalization = normalization
        self.lins = torch.nn.ModuleList([
            torch.nn.Linear(emb_dim, emb_dim, bias=False) for _ in range(K)
        ])

        if bias:
            self.bias = Parameter(torch.Tensor(emb_dim))
        else:
            self.register_parameter('bias', None)

    def __norm__(self, edge_index, num_nodes: Optional[int],
                 edge_weight: OptTensor, normalization: Optional[str],
                 lambda_max, dtype: Optional[int] = None,
                 batch: OptTensor = None):

        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, edge_weight = get_laplacian(edge_index, edge_weight,
                                                normalization, dtype,
                                                num_nodes)

        if batch is not None and lambda_max.numel() > 1:
            lambda_max = lambda_max[batch[edge_index[0]]]

        # Rescale the Laplacian to [-1, 1] and add -I via self-loops.
        edge_weight = (2.0 * edge_weight) / lambda_max
        edge_weight.masked_fill_(edge_weight == float('inf'), 0)

        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 fill_value=-1.,
                                                 num_nodes=num_nodes)
        assert edge_weight is not None
        return edge_index, edge_weight

    def forward(self, x, edge_index, edge_attr: OptTensor = None,
                batch: OptTensor = None, lambda_max: OptTensor = None):
        """"""
        edge_embedding = self.edge_encoder(edge_attr)

        if self.normalization != 'sym' and lambda_max is None:
            raise ValueError('You need to pass `lambda_max` to `forward()` '
                             'in case the normalization is non-symmetric.')

        if lambda_max is None:
            lambda_max = torch.tensor(2.0, dtype=x.dtype, device=x.device)
        if not isinstance(lambda_max, torch.Tensor):
            lambda_max = torch.tensor(lambda_max, dtype=x.dtype,
                                      device=x.device)
        assert lambda_max is not None

        edge_index, norm = self.__norm__(edge_index, x.size(self.node_dim),
                                         None, self.normalization,
                                         lambda_max, dtype=x.dtype,
                                         batch=batch)

        # Split the self-loop terms off so they can be applied to the root
        # embedding directly instead of going through message passing.
        # edge_index, norm = coalesce(edge_index, norm, m=x.shape[0],
        #                             n=x.shape[0])
        edge_index, norm, loop_edge_index, loop_norm = segregate_self_loops(
            edge_index, norm)
        loop_edge_index, loop_norm = coalesce(loop_edge_index, loop_norm,
                                              m=x.shape[0], n=x.shape[0])

        # Chebyshev recurrence: T_k = 2 * L_hat * T_{k-1} - T_{k-2}.
        Tx_0 = x
        Tx_1 = x  # Dummy.
        out = self.lins[0](Tx_0)

        # propagate_type: (x: Tensor, norm: Tensor)
        if len(self.lins) > 1:
            Tx_1 = (self.propagate(edge_index, x=x, edge_attr=edge_embedding,
                                   norm=norm, size=None)
                    + loop_norm.view(-1, 1)
                    * F.relu(x + self.root_emb.weight))
            out = out + self.lins[1](Tx_1)

        for lin in self.lins[2:]:
            Tx_2 = (self.propagate(edge_index, x=Tx_1,
                                   edge_attr=edge_embedding, norm=norm,
                                   size=None)
                    + loop_norm.view(-1, 1)
                    * F.relu(Tx_1 + self.root_emb.weight))
            Tx_2 = 2. * Tx_2 - Tx_0
            out = out + lin.forward(Tx_2)
            Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out += self.bias

        return out

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)
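

# A minimal usage sketch for ChebConvNew with K=3 and the default symmetric
# normalization, so lambda_max can be left unset (it defaults to 2.0). As
# above, this assumes a non-'mol' dataset_group with 7-dimensional raw edge
# features; the helper name is purely illustrative.
def _example_cheb_conv(emb_dim=64, K=3):
    conv = ChebConvNew(emb_dim, K, dataset_group='other')
    x = torch.randn(4, emb_dim)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 0, 3, 2]])
    edge_attr = torch.randn(4, 7)
    return conv(x, edge_index, edge_attr)          # -> [4, emb_dim]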