Line too long (92 > 79 characters):
    3 from torch_geometric.nn import global_mean_pool, global_add_pool, GCNConv, GINConv, ChebConv
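The flagged import can be brought under 79 columns with a parenthesized multi-line form; a minimal sketch using the same names:

    from torch_geometric.nn import (
        global_mean_pool,
        global_add_pool,
        GCNConv,
        GINConv,
        ChebConv,
    )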
Line too long (100 > 79 characters):
    14 Graph Isomorphism Network augmented with virtual node for multi-task binary graph classification
Line too long (82 > 79 characters):
    18 - prediction (Tensor): float torch tensor of shape (num_graphs, num_tasks)
Line too long (118 > 79 characters):
    21 def __init__(self, gnn_type, dataset_group, num_tasks=128, num_layers=5, emb_dim=300, dropout=0.5, is_pooled=True,
Line too long (109 > 79 characters):
    25 - num_tasks (int): number of binary label tasks. default to 128 (number of tasks of ogbg-molpcba)
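The `__init__` signature flagged at line 21 fits once its parameters are wrapped with a hanging indent. A minimal sketch, assuming the enclosing class (whose name is not shown in the report) is a plain `torch.nn.Module`; parameters that continue past the flagged fragment are left out:

    import torch


    class GNN(torch.nn.Module):  # class name assumed, not shown above
        def __init__(self, gnn_type, dataset_group, num_tasks=128,
                     num_layers=5, emb_dim=300, dropout=0.5,
                     is_pooled=True):
            # Parameters after `is_pooled` in the original are omitted.
            super().__init__()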
Line too long (103 > 79 characters):
    56 self.gnn_node = GNN_node_Virtualnode(num_layers, emb_dim, dataset_group=self.dataset_group,
Line too long (103 > 79 characters):
    57 gnn_type=self.gnn_type.split('_')[0], dropout=dropout,
Line too long (91 > 79 characters):
    60 self.gnn_node = GNN_node(num_layers, emb_dim, dataset_group=self.dataset_group,
Line too long (91 > 79 characters):
    61 gnn_type=self.gnn_type.split('_')[0], dropout=dropout,
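Lines 56-57 and 60-61 are calls whose continuation lines are aligned far to the right of the opening parenthesis; a hanging indent keeps each argument line short. A sketch of the pattern only: the stub below stands in for the real `GNN_node_Virtualnode` (apparently defined later in the same file) so the snippet runs on its own, the `self.` prefixes are dropped for the same reason, and the values ('gin_virtual', 'mol', 5, 300, 0.5) are illustrative, taken from the defaults shown above.

    def GNN_node_Virtualnode(num_layers, emb_dim, dataset_group='mol',
                             gnn_type='gin', dropout=0.5):
        """Stand-in stub; the real class lives elsewhere in this file."""
        return (num_layers, emb_dim, dataset_group, gnn_type, dropout)


    gnn_type = 'gin_virtual'  # illustrative value
    gnn_node = GNN_node_Virtualnode(
        5, 300,
        dataset_group='mol',
        gnn_type=gnn_type.split('_')[0],
        dropout=0.5,
    )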
Line too long (82 > 79 characters):
    73 self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)
Line too long (80 > 79 characters):
    84 return self.graph_pred_linear(self.pool(h_node, batched_data.batch))
Line too long (103 > 79 characters):
    94 def __init__(self, num_layer, emb_dim, dataset_group='mol', gnn_type='gin', dropout=0.5, JK="last",
Line too long (93 > 79 characters):
    119 self.node_encoder = torch.nn.Embedding(1, emb_dim) # uniform input node embedding
Line too long (90 > 79 characters):
    138 if self.dataset_group in ['RotatedMNIST', 'ColoredMNIST', 'SBM', 'UPFD']:
Line too long (84 > 79 characters):
    139 mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim),
Line too long (97 > 79 characters):
    140 torch.nn.BatchNorm1d(2 * emb_dim), torch.nn.ReLU(),
Line too long (84 > 79 characters):
    141 torch.nn.Linear(2 * emb_dim, emb_dim))
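Lines 139-141 build an MLP with `torch.nn.Sequential`, again with continuation lines aligned to the opening parenthesis; one module per line with a hanging indent stays under the limit. A self-contained sketch, using the `emb_dim=300` default shown at line 21:

    import torch

    emb_dim = 300
    mlp = torch.nn.Sequential(
        torch.nn.Linear(emb_dim, 2 * emb_dim),
        torch.nn.BatchNorm1d(2 * emb_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(2 * emb_dim, emb_dim),
    )

The same pattern applies to the duplicated block at lines 262-264 and to the virtual-node MLP at lines 285-287.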
Line too long (90 > 79 characters):
    146 if self.dataset_group in ['RotatedMNIST', 'ColoredMNIST', 'SBM', 'UPFD']:
Line too long (90 > 79 characters):
    151 if self.dataset_group in ['RotatedMNIST', 'ColoredMNIST', 'SBM', 'UPFD']:
Line too long (87 > 79 characters):
    154 self.convs.append(ChebConvNew(emb_dim, Cheb_K, self.dataset_group))
Line too long (81 > 79 characters):
    156 raise ValueError('Undefined GNN type called {}'.format(gnn_type))
Line too long (125 > 79 characters):
    164 x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch
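The unpacking flagged at line 164 (repeated at line 295) stays under the limit if each attribute is read on its own line. A sketch with a stand-in for the PyG `Batch` argument so the snippet runs standalone:

    from types import SimpleNamespace

    # Stand-in for the batched_data (torch_geometric Batch) argument.
    batched_data = SimpleNamespace(
        x=None, edge_index=None, edge_attr=None, batch=None,
    )

    x = batched_data.x
    edge_index = batched_data.edge_index
    edge_attr = batched_data.edge_attr
    batch = batched_data.batch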
Line too long (236 > 79 characters):
    171 # h_list = [self.node_encoder(x[:,:2]) + self.node_encoder_cate(x[:,2:].to(torch.int).squeeze()) + perturb if perturb is not None else self.node_encoder(x[:,:2]) + self.node_encoder_cate(x[:,2:].to(torch.int).squeeze())]
Line too long (102 > 79 characters):
    173 h_list = [self.node_encoder(x) + perturb if perturb is not None else self.node_encoder(x)]
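Line 173 (and the identical line 308) squeezes a conditional expression into the list literal; an explicit if/else is both shorter per line and easier to read. A sketch with stand-ins for `self.node_encoder`, `x`, and `perturb`:

    import torch

    # Stand-ins for self.node_encoder, x, and perturb in the original.
    node_encoder = torch.nn.Embedding(1, 300)
    x = torch.zeros(4, dtype=torch.long)
    perturb = None

    h = node_encoder(x)
    if perturb is not None:
        h = h + perturb
    h_list = [h]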
Line too long (81 > 79 characters):
    185 h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
Line too long (103 > 79 characters):
    210 def __init__(self, num_layer, emb_dim, dataset_group='mol', gnn_type='gin', dropout=0.5, JK="last",
Line too long (93 > 79 characters):
    234 self.node_encoder = torch.nn.Embedding(1, emb_dim) # uniform input node embedding
Line too long (90 > 79 characters):
    261 if self.dataset_group in ['RotatedMNIST', 'ColoredMNIST', 'SBM', 'UPFD']:
Line too long (84 > 79 characters):
    262 mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim),
Line too long (97 > 79 characters):
    263 torch.nn.BatchNorm1d(2 * emb_dim), torch.nn.ReLU(),
Line too long (84 > 79 characters):
    264 torch.nn.Linear(2 * emb_dim, emb_dim))
Line too long (90 > 79 characters):
    269 if self.dataset_group in ['RotatedMNIST', 'ColoredMNIST', 'SBM', 'UPFD']:
Line too long (90 > 79 characters):
    274 if self.dataset_group in ['RotatedMNIST', 'ColoredMNIST', 'SBM', 'UPFD']:
Line too long (87 > 79 characters):
    277 self.convs.append(ChebConvNew(emb_dim, Cheb_K, self.dataset_group))
Line too long (81 > 79 characters):
    279 raise ValueError('Undefined GNN type called {}'.format(gnn_type))
Line too long (109 > 79 characters):
    285 torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),
Line too long (105 > 79 characters):
    287 torch.nn.Linear(2 * emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim),
Line too long (125 > 79 characters):
    295 x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch
Line too long (89 > 79 characters):
    299 torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))
Line too long (236 > 79 characters):
    306 # h_list = [self.node_encoder(x[:,:2]) + self.node_encoder_cate(x[:,2:].to(torch.int).squeeze()) + perturb if perturb is not None else self.node_encoder(x[:,:2]) + self.node_encoder_cate(x[:,2:].to(torch.int).squeeze())]
Line too long (102 > 79 characters):
    308 h_list = [self.node_encoder(x) + perturb if perturb is not None else self.node_encoder(x)]
Line too long (81 > 79 characters):
    322 h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
Line too long (106 > 79 characters):
    332 virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
Line too long (102 > 79 characters):
    337 self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio,
Line too long (115 > 79 characters):
    340 virtualnode_embedding = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
Line too long (94 > 79 characters):
    341 self.drop_ratio, training=self.training)
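Lines 332-341 update the virtual-node embedding; splitting the `F.dropout` arguments over separate lines, and wrapping the pooled sum in parentheses, brings every line under 79 columns. A self-contained sketch: `torch.nn.Identity` stands in for the real `self.mlp_virtualnode_list` entry, and `training=False` replaces `self.training`.

    import torch
    import torch.nn.functional as F
    from torch_geometric.nn import global_add_pool

    emb_dim, drop_ratio, layer = 300, 0.5, 0
    mlp_virtualnode_list = torch.nn.ModuleList([torch.nn.Identity()])

    h = torch.zeros(4, emb_dim)               # node embeddings at this layer
    batch = torch.zeros(4, dtype=torch.long)  # all nodes in one graph
    virtualnode_embedding = torch.zeros(1, emb_dim)

    virtualnode_embedding_temp = (
        global_add_pool(h, batch) + virtualnode_embedding
    )
    virtualnode_embedding = F.dropout(
        mlp_virtualnode_list[layer](virtualnode_embedding_temp),
        drop_ratio,
        training=False,
    )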