datasets/ogbmolhiv_dataset.py source

1 import os
2 from collections import namedtuple
3 import numpy as np
4 import torch
5 import torch_geometric
6 from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
7 from torch_geometric.data import download_url, extract_zip
8 from torch_geometric.data.dataloader import Collater as PyGCollater
  • F811 Redefinition of unused 'namedtuple' from line 2
  • F401 'collections.namedtuple' imported but unused
9 from collections import namedtuple
  • F401 'tqdm' imported but unused
10 import tqdm
11 from gds.datasets.gds_dataset import GDSDataset
12 from torch_geometric.utils import to_dense_adj
13 import torch.nn.functional as F
14  
15  
16 class OGBHIVDataset(GDSDataset):
17     """
18     The OGB-molhiv dataset.
   • E501 Line too long (102 > 79 characters)
19     This dataset is directly adopted from Open Graph Benchmark, and originally curated by MoleculeNet.
20
21     Supported `split_scheme`:
22         - 'official' or 'scaffold', which are equivalent
23
24     Input (x):
25         Molecular graphs represented as PyTorch Geometric data objects
26
27     Label (y):
28         y is a single binary label indicating whether the molecule inhibits HIV virus replication.
29
30     Metadata:
31         - scaffold
   • E501 Line too long (93 > 79 characters)
32             Each molecule is annotated with the ID of the scaffold group it is assigned to.
33
34     Website:
35         https://ogb.stanford.edu/docs/graphprop/#ogbg-mol
36
37     Original publication:
38         @article{hu2020ogb,
   • E501 Line too long (82 > 79 characters)
39             title={Open Graph Benchmark: Datasets for Machine Learning on Graphs},
   • E501 Line too long (112 > 79 characters)
40             author={W. {Hu}, M. {Fey}, M. {Zitnik}, Y. {Dong}, H. {Ren}, B. {Liu}, M. {Catasta}, J. {Leskovec}},
41             journal={arXiv preprint arXiv:2005.00687},
42             year={2020}
43         }
44
45         @article{wu2018moleculenet,
46             title={MoleculeNet: a benchmark for molecular machine learning},
   • E501 Line too long (129 > 79 characters)
47             author={Z. {Wu}, B. {Ramsundar}, E. V {Feinberg}, J. {Gomes}, C. {Geniesse}, A. S {Pappu}, K. {Leswing}, V. {Pande}},
48             journal={Chemical science},
49             volume={9},
50             number={2},
51             pages={513--530},
52             year={2018},
53             publisher={Royal Society of Chemistry}
54         }
55
56     License:
57         This dataset is distributed under the MIT license.
58         https://github.com/snap-stanford/ogb/blob/master/LICENSE
59     """
60  
61     _dataset_name = 'ogb-molhiv'
62     _versions_dict = {
63         '1.0': {
64             'download_url': None,
65             'compressed_size': None}}
66  
  • E501 Line too long (114 > 79 characters)
67     def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', random_split=False,
68                  subgraph=False, **dataset_kwargs):
69         self._version = version
70         if version is not None:
   • E501 Line too long (118 > 79 characters)
71             raise ValueError('Versioning for OGB-MolHIV is handled through the OGB package. Please set version=None.')
72         # internally call the ogb package
   • E501 Line too long (85 > 79 characters)
73         self.ogb_dataset = PygGraphPropPredDataset(name='ogbg-molhiv', root=root_dir)
74
75         # set variables
76         self._data_dir = self.ogb_dataset.root
77         if split_scheme == 'official':
78             split_scheme = 'scaffold'
79         self._split_scheme = split_scheme
   • E501 Line too long (138 > 79 characters)
80         self._y_type = 'float'  # although the task is binary classification, the prediction target may contain NaN values, so we need float
81         self._y_size = self.ogb_dataset.num_tasks
82         # self._n_classes = self.ogb_dataset.__num_classes__
83         self._n_classes = 1
84
85         self._split_array = torch.zeros(len(self.ogb_dataset)).long()
86         split_idx = self.ogb_dataset.get_idx_split()
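        # get_idx_split() returns a dict of 'train' / 'valid' / 'test' index
        # tensors for OGB's official scaffold split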
87  
88         np.random.seed(0)
89         dataset_size = len(self.ogb_dataset)
90         if random_split:
91             random_index = np.random.permutation(dataset_size)
92             train_split_idx = random_index[:len(split_idx['train'])]
   • E501 Line too long (115 > 79 characters)
93             val_split_idx = random_index[len(split_idx['train']):len(split_idx['train']) + len(split_idx['valid'])]
   • E501 Line too long (93 > 79 characters)
94             test_split_idx = random_index[len(split_idx['train']) + len(split_idx['valid']):]
95         else:
96             train_split_idx = split_idx['train']
97             val_split_idx = split_idx['valid']
98             test_split_idx = split_idx['test']
99
100         self._split_array[train_split_idx] = 0
101         self._split_array[val_split_idx] = 1
102         self._split_array[test_split_idx] = 2
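        # split codes: 0 = train, 1 = validation, 2 = test; with random_split=True
        # the official split sizes are kept, but the indices are drawn from a
        # seeded random permutation instead of the scaffold split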
103  
104  
  • E303 Too many blank lines (2)
105         self._y_array = self.ogb_dataset.data.y
106         self._metadata_fields = ['scaffold', 'y']
107
   • E501 Line too long (95 > 79 characters)
108         metadata_file_path = os.path.join(self.ogb_dataset.root, 'raw', 'OGB-MolHIV_group.npy')
109         if not os.path.exists(metadata_file_path):
110             metadata_zip_file_path = download_url(
   • E501 Line too long (112 > 79 characters)
111                 'https://www.dropbox.com/s/i5z388zxbut0quo/OGB-MolHIV_group.zip?dl=1', self.ogb_dataset.raw_dir)
112             extract_zip(metadata_zip_file_path, self.ogb_dataset.raw_dir)
113             os.unlink(metadata_zip_file_path)
   • E501 Line too long (103 > 79 characters)
114         self._metadata_array_wo_y = torch.from_numpy(np.load(metadata_file_path)).reshape(-1, 1).long()
   • E501 Line too long (97 > 79 characters)
115         self._metadata_array = torch.cat((self._metadata_array_wo_y, self.ogb_dataset.data.y), 1)
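        # metadata columns: [scaffold group ID, y]; the group IDs come from the
        # OGB-MolHIV_group.npy file downloaded above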
116  
117         if dataset_kwargs['model'] == '3wlgnn':
118             self._collate = self.collate_dense
119         else:
120             if torch_geometric.__version__ >= '1.7.0':
121                 self._collate = PyGCollater(follow_batch=[], exclude_keys=[])
122             else:
123                 self._collate = PyGCollater(follow_batch=[])
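        # dense collation is only needed for models that consume dense adjacency
        # tensors (3WLGNN); all other models use PyG's sparse mini-batching, whose
        # Collater is given exclude_keys only on torch_geometric >= 1.7.0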
124  
125         self._metric = Evaluator('ogbg-molhiv')
126
127         # GSN
128         self.subgraph = subgraph
129         if self.subgraph:
130             self.id_type = dataset_kwargs['gsn_id_type']
131             self.k = dataset_kwargs['gsn_k']
132             from gds.datasets.gsn.gsn_data_prep import GSN
   • E501 Line too long (111 > 79 characters)
133             subgraph = GSN(dataset_name='ogbg-molhiv', dataset_group='ogb', induced=True, id_type=self.id_type,
134                            k=self.k)
   • E501 Line too long (116 > 79 characters)
135             self.graphs_ptg, self.encoder_ids, self.d_id, self.d_degree = subgraph.preprocess(self.ogb_dataset.root)
136
137             if self.graphs_ptg[0].x.dim() == 1:
138                 self.num_features = 1
139             else:
140                 self.num_features = self.graphs_ptg[0].num_features
141
142             if hasattr(self.graphs_ptg[0], 'edge_features'):
143                 if self.graphs_ptg[0].edge_features.dim() == 1:
144                     self.num_edge_features = 1
145                 else:
   • E501 Line too long (86 > 79 characters)
146                     self.num_edge_features = self.graphs_ptg[0].edge_features.shape[1]
147             else:
148                 self.num_edge_features = None
149
150             self.d_in_node_encoder = [self.num_features]
151             self.d_in_edge_encoder = [self.num_edge_features]
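            # GSN (Graph Substructure Networks) preprocessing: graphs_ptg holds the
            # molecules augmented with substructure-count features controlled by
            # gsn_id_type / gsn_k; the remaining attributes record the feature
            # dimensions used to size the node and edge encoders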
152  
153         super().__init__(root_dir, download, split_scheme)
154  
155     def get_input(self, idx):
156         if self.subgraph:
157             return self.graphs_ptg[int(idx)]
158         else:
159             return self.ogb_dataset[int(idx)]
160  
161     def eval(self, y_pred, y_true, metadata, prediction_fn=None):
162         """
163         Computes all evaluation metrics.
164         Args:
165             - y_pred (FloatTensor): Binary logits from a model
166             - y_true (LongTensor): Ground-truth labels
167             - metadata (Tensor): Metadata
   • E501 Line too long (91 > 79 characters)
168             - prediction_fn (function): A function that turns y_pred into predicted labels.
   • E501 Line too long (106 > 79 characters)
169                 Only None is supported because OGB Evaluators accept binary logits
170         Output:
171             - results (dictionary): Dictionary of evaluation metrics
172             - results_str (str): String summarizing the evaluation metrics
173         """
   • E501 Line too long (120 > 79 characters)
174         assert prediction_fn is None, "OGBHIVDataset.eval() does not support prediction_fn. Only binary logits accepted"
175         input_dict = {"y_true": y_true, "y_pred": y_pred}
   • W293 Blank line contains whitespace
176
177         results = self._metric.eval(input_dict)
178
179         return results, f"ROCAUC: {results['rocauc']:.3f}\n"
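    # Illustrative call of eval(), assuming logits and labels of shape (n_molecules, 1):
    #     results, results_str = dataset.eval(logits, labels, metadata)
    # results is the OGB Evaluator's metric dict and contains the 'rocauc' key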
180  
181     # prepare dense tensors for GNNs that use them, such as RingGNN and 3WLGNN
182     def collate_dense(self, samples):
183         def _sym_normalize_adj(adjacency):
184             deg = torch.sum(adjacency, dim=0)  # .squeeze()
   • E501 Line too long (89 > 79 characters)
185             deg_inv = torch.where(deg > 0, 1. / torch.sqrt(deg), torch.zeros(deg.size()))
186             deg_inv = torch.diag(deg_inv)
187             return torch.mm(deg_inv, torch.mm(adjacency, deg_inv))
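            # symmetric normalization D^{-1/2} A D^{-1/2}; isolated nodes
            # (degree 0) are mapped to zero instead of dividing by zero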
188  
189         # each sample is a (graph, y, metadata) triple
190         node_feat_space = torch.tensor([119, 4, 12, 12, 10, 6, 6, 2, 2])
191         edge_feat_space = torch.tensor([5, 6, 2])
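        # cardinalities of the 9 categorical atom features and 3 categorical bond
        # features in OGB's molecule encoding; the one-hot widths below are derived
        # from these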
192  
193         graph_list, y_list, metadata_list = map(list, zip(*samples))
194         y, metadata = torch.tensor(y_list), torch.stack(metadata_list)
195
196         # insert size one at dim 0 because this dataset's y is 1d
197         y = y.unsqueeze(0)
198  
199         feat = []
200         for graph in graph_list:
   • E501 Line too long (109 > 79 characters)
201             adj = _sym_normalize_adj(to_dense_adj(graph.edge_index, max_num_nodes=graph.x.size(0)).squeeze())
202             zero_adj = torch.zeros_like(adj)
203             in_dim = node_feat_space.sum() + edge_feat_space.sum()
204
205             # use node feats to prepare adj
206             adj_feat = torch.stack([zero_adj for _ in range(in_dim)])
207             adj_feat = torch.cat([adj.unsqueeze(0), adj_feat], dim=0)
208
209             def convert(feature, space):
210                 out = []
211                 for i, label in enumerate(feature):
212                     out.append(F.one_hot(label, space[i]))
213                 return torch.cat(out)
214
215             for node, node_feat in enumerate(graph.x):
   • E501 Line too long (103 > 79 characters)
216                 adj_feat[1:1 + node_feat_space.sum(), node, node] = convert(node_feat, node_feat_space)
217             for edge in range(graph.edge_index.shape[1]):
   • E501 Line too long (85 > 79 characters)
218                 target, source = graph.edge_index[0][edge], graph.edge_index[1][edge]
219                 edge_feat = graph.edge_attr[edge]
   • E501 Line too long (106 > 79 characters)
220                 adj_feat[1 + node_feat_space.sum():, target, source] = convert(edge_feat, edge_feat_space)
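            # adj_feat layout per graph: channel 0 is the normalized adjacency,
            # the next sum(node_feat_space) channels hold one-hot atom features on
            # the diagonal, and the last sum(edge_feat_space) channels hold one-hot
            # bond features at the (target, source) entries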
221  
222             feat.append(adj_feat)
223
224         feat = torch.stack(feat)
225         return feat, y, metadata
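
# Illustrative usage sketch; 'gin' below is a placeholder for any model name other
# than '3wlgnn', which selects the sparse PyG collate path:
#
#     dataset = OGBHIVDataset(root_dir='data', split_scheme='official', model='gin')
#     graph = dataset.get_input(0)       # a torch_geometric Data object
#     splits = dataset._split_array      # 0 = train, 1 = validation, 2 = test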