# datasets/ogbgppa_dataset.py

import os

import numpy as np
import pandas as pd
import torch
import torch_geometric
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from torch_geometric.data.dataloader import Collater as PyGCollater
from torch_geometric.utils import to_dense_adj

from gds.datasets.gds_dataset import GDSDataset

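# ogbg-ppa graphs come without input node features, so a constant
# zero-valued node label is attached to every node, mirroring the
# transform used in the official OGB examples.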
def add_zeros(data):
    data.x = torch.zeros(data.num_nodes, dtype=torch.long)
    return data


class OGBGPPADataset(GDSDataset):
    """
    The OGB-ppa dataset.
    This dataset is directly adopted from Open Graph Benchmark; its graphs are
    undirected protein association neighborhoods extracted from species-level
    protein-protein association networks. A minimal usage sketch is provided
    under the ``__main__`` guard at the bottom of this file.

    Supported `split_scheme`:
        - 'official' or 'scaffold', which are equivalent

    Input (x):
        Protein association neighborhood graphs represented as Pytorch
        Geometric data objects. The graphs carry no input node features, so
        `add_zeros` attaches a constant zero node label to every node.

    Label (y):
        y is the taxonomic group that the graph originates from
        (a 37-way classification target).

    Metadata:
        - species
            Each graph is annotated with the ID of the species it was
            extracted from.

    Website:
        https://ogb.stanford.edu/docs/graphprop/#ogbg-ppa

    Original publication:
        @article{hu2020ogb,
            title={Open Graph Benchmark: Datasets for Machine Learning on Graphs},
            author={W. {Hu}, M. {Fey}, M. {Zitnik}, Y. {Dong}, H. {Ren},
                    B. {Liu}, M. {Catasta}, J. {Leskovec}},
            journal={arXiv preprint arXiv:2005.00687},
            year={2020}
        }

    License:
        This dataset is distributed under the MIT license.
        https://github.com/snap-stanford/ogb/blob/master/LICENSE
    """

    _dataset_name = 'ogbg-ppa'
    _versions_dict = {
        '1.0': {
            'download_url': None,
            'compressed_size': None}}

    def __init__(self, version=None, root_dir='data', download=False,
                 split_scheme='official', random_split=False, subgraph=False,
                 **dataset_kwargs):
        self._version = version
        if version is not None:
            raise ValueError(
                'Versioning for OGB-PPA is handled through the OGB package. '
                'Please set version=None.')
        # internally call the ogb package
        self.ogb_dataset = PygGraphPropPredDataset(
            name='ogbg-ppa', root=root_dir, transform=add_zeros)

        # set variables
        self._data_dir = self.ogb_dataset.root
        if split_scheme == 'official':
            split_scheme = 'scaffold'
        self._split_scheme = split_scheme
        # y is a single class index per graph (multi-class classification)
        self._y_type = 'float'
        self._y_size = self.ogb_dataset.num_tasks
        self._n_classes = self.ogb_dataset.__num_classes__

        self._split_array = torch.zeros(len(self.ogb_dataset)).long()
        split_idx = self.ogb_dataset.get_idx_split()

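        # With random_split=True the official split sizes are preserved but
        # graphs are re-assigned to train/val/test uniformly at random (with a
        # fixed seed); otherwise the official OGB split indices are used as-is.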
        np.random.seed(0)
        dataset_size = len(self.ogb_dataset)
        if random_split:
            random_index = np.random.permutation(dataset_size)
            train_split_idx = random_index[:len(split_idx['train'])]
            val_split_idx = random_index[
                len(split_idx['train']):len(split_idx['train']) + len(split_idx['valid'])]
            test_split_idx = random_index[
                len(split_idx['train']) + len(split_idx['valid']):]
        else:
            train_split_idx = split_idx['train']
            val_split_idx = split_idx['valid']
            test_split_idx = split_idx['test']

        self._split_array[train_split_idx] = 0
        self._split_array[val_split_idx] = 1
        self._split_array[test_split_idx] = 2

        self._y_array = self.ogb_dataset.data.y.squeeze()

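        # Per-graph metadata: the species ID each graph was extracted from
        # (read from OGB's mapping file and shifted so that IDs start at 0),
        # concatenated with the label y.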
        self._metadata_fields = ['species', 'y']
        metadata_file_path = os.path.join(
            self.ogb_dataset.root, 'mapping', 'graphidx2speciesid.csv.gz')
        df = pd.read_csv(metadata_file_path)
        df_np = df['species id'].to_numpy()
        self._metadata_array_wo_y = torch.from_numpy(
            df_np - df_np.min()).reshape(-1, 1).long()
        self._metadata_array = torch.cat(
            (self._metadata_array_wo_y, self.ogb_dataset.data.y), 1)

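        # Dense models such as 3WLGNN consume dense adjacency/edge-feature
        # tensors and therefore use collate_dense below; every other model
        # uses the standard PyG collater (whose signature changed in PyG 1.7.0).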
        if dataset_kwargs['model'] == '3wlgnn':
            self._collate = self.collate_dense
        else:
            if torch_geometric.__version__ >= '1.7.0':
                self._collate = PyGCollater(follow_batch=[], exclude_keys=[])
            else:
                self._collate = PyGCollater(follow_batch=[])

        self._metric = Evaluator('ogbg-ppa')

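        # GSN (Graph Substructure Network) support: when subgraph=True the
        # graphs are re-preprocessed so that they additionally carry structural
        # identifiers derived from counting the configured substructures
        # (gsn_id_type, gsn_k).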
        self.subgraph = subgraph
        if self.subgraph:
            self.id_type = dataset_kwargs['gsn_id_type']
            self.k = dataset_kwargs['gsn_k']
            from gds.datasets.gsn.gsn_data_prep import GSN
            subgraph = GSN(dataset_name='ogbg-ppa', dataset_group='ogb',
                           induced=True, id_type=self.id_type, k=self.k)
            self.graphs_ptg, self.encoder_ids, self.d_id, self.d_degree = \
                subgraph.preprocess(self.ogb_dataset.root)

            if self.graphs_ptg[0].x.dim() == 1:
                self.num_features = 1
            else:
                self.num_features = self.graphs_ptg[0].num_features

            if hasattr(self.graphs_ptg[0], 'edge_features'):
                if self.graphs_ptg[0].edge_features.dim() == 1:
                    self.num_edge_features = 1
                else:
                    self.num_edge_features = self.graphs_ptg[0].edge_features.shape[1]
            else:
                self.num_edge_features = None

            self.d_in_node_encoder = [self.num_features]
            self.d_in_edge_encoder = [self.num_edge_features]

        super().__init__(root_dir, download, split_scheme)

    def get_input(self, idx):
        if self.subgraph:
            return self.graphs_ptg[int(idx)]
        else:
            return self.ogb_dataset[int(idx)]

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Computes all evaluation metrics.
        Args:
            - y_pred (FloatTensor): Multi-class logits from a model
            - y_true (LongTensor): Ground-truth labels
            - metadata (Tensor): Metadata
            - prediction_fn (function): A function that turns y_pred into
              predicted labels. Only None is supported; the logits are
              converted to predicted class indices internally before being
              passed to the OGB Evaluator.
        Output:
            - results (dictionary): Dictionary of evaluation metrics
            - results_str (str): String summarizing the evaluation metrics
        """
        assert prediction_fn is None, \
            "OGBGPPADataset.eval() does not support prediction_fn; only multi-class logits are accepted"
        y_true = y_true.view(-1, 1)
        y_pred = torch.argmax(y_pred.detach(), dim=1).view(-1, 1)
        input_dict = {"y_true": y_true, "y_pred": y_pred}

        results = self._metric.eval(input_dict)

        return results, f"Accuracy: {results['acc']:.3f}\n"

    # prepare dense tensors for GNNs that use them, such as RingGNN and 3WLGNN
    def collate_dense(self, samples):
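        # _sym_normalize_adj applies the symmetric normalization
        # D^{-1/2} A D^{-1/2} to a dense adjacency matrix, guarding against
        # division by zero for isolated nodes.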
        def _sym_normalize_adj(adjacency):
            deg = torch.sum(adjacency, dim=0)  # .squeeze()
            deg_inv = torch.where(deg > 0, 1. / torch.sqrt(deg),
                                  torch.zeros(deg.size()))
            deg_inv = torch.diag(deg_inv)
            return torch.mm(deg_inv, torch.mm(adjacency, deg_inv))

        # Each sample is a triple (graph, label, metadata).
        graph_list, y_list, metadata_list = map(list, zip(*samples))
        y, metadata = torch.tensor(y_list), torch.stack(metadata_list)

        x_edge_feat = []
        for graph in graph_list:
            adj = _sym_normalize_adj(
                to_dense_adj(graph.edge_index,
                             max_num_nodes=graph.x.size(0)).squeeze())
            zero_adj = torch.zeros_like(adj)
            in_dim = graph.edge_attr.shape[1]

            # one dense channel per edge-feature dimension, preceded by the
            # normalized adjacency itself
            adj_edge_feat = torch.stack([zero_adj for _ in range(in_dim)])
            adj_edge_feat = torch.cat([adj.unsqueeze(0), adj_edge_feat], dim=0)

            for edge in range(graph.edge_index.shape[1]):
                target, source = graph.edge_index[0][edge], graph.edge_index[1][edge]
                adj_edge_feat[1:, target, source] = graph.edge_attr[edge]

            x_edge_feat.append(adj_edge_feat)

        x_edge_feat = torch.stack(x_edge_feat)
        return x_edge_feat, y, metadata
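

if __name__ == '__main__':
    # Minimal usage sketch. It assumes `ogb`, `torch_geometric` and the `gds`
    # package are installed; constructing the dataset downloads ogbg-ppa into
    # `data/` on first use (a large download). Any model name other than
    # '3wlgnn' selects the standard PyG collater; 'gin' is only a placeholder
    # value here.
    dataset = OGBGPPADataset(root_dir='data', split_scheme='official',
                             model='gin')
    print(f"{len(dataset.ogb_dataset)} graphs, "
          f"{dataset._n_classes} classes, "
          f"split scheme: {dataset._split_scheme}")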