⬅ datasets/ogbmolpcba_dataset.py source

1 import os
2  
3 import numpy as np
4 import torch
5 import torch_geometric
6 from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
7 from torch_geometric.data import download_url, extract_zip
8 from torch_geometric.data.dataloader import Collater as PyGCollater
9  
10 from gds.datasets.gds_dataset import GDSDataset
11 from torch_geometric.utils import to_dense_adj
12 import torch.nn.functional as F
13  
14  
class OGBPCBADataset(GDSDataset):
    """
    The OGB-molpcba dataset.
    This dataset is directly adopted from Open Graph Benchmark, and originally curated by MoleculeNet.

    Supported `split_scheme`:
        - 'official' or 'scaffold', which are equivalent

    Input (x):
        Molecular graphs represented as PyTorch Geometric data objects

    Label (y):
        y represents 128-class binary labels.
        Labels are stored as floats because the prediction target contains NaN values.

    Metadata:
        - scaffold
            Each molecule is annotated with the scaffold ID that the molecule is assigned to.

    Website:
        https://ogb.stanford.edu/docs/graphprop/#ogbg-mol

    Original publication:
        @article{hu2020ogb,
            title={Open Graph Benchmark: Datasets for Machine Learning on Graphs},
            author={W. {Hu}, M. {Fey}, M. {Zitnik}, Y. {Dong}, H. {Ren}, B. {Liu}, M. {Catasta}, J. {Leskovec}},
            journal={arXiv preprint arXiv:2005.00687},
            year={2020}
        }

        @article{wu2018moleculenet,
            title={MoleculeNet: a benchmark for molecular machine learning},
            author={Z. {Wu}, B. {Ramsundar}, E. V {Feinberg}, J. {Gomes}, C. {Geniesse}, A. S {Pappu}, K. {Leswing}, V. {Pande}},
            journal={Chemical science},
            volume={9},
            number={2},
            pages={513--530},
            year={2018},
            publisher={Royal Society of Chemistry}
        }

    License:
        This dataset is distributed under the MIT license.
        https://github.com/snap-stanford/ogb/blob/master/LICENSE
    """

    _dataset_name = 'ogb-molpcba'
    _versions_dict = {
        '1.0': {
            'download_url': None,
            'compressed_size': None}}

    # Number of categories of each node / edge feature column, used for one-hot encoding in collate_dense.
    # NOTE(review): presumably the OGB atom / bond feature vocabulary sizes -- confirm against ogb.utils.features.
    _NODE_FEAT_SPACE = (119, 4, 12, 12, 10, 6, 6, 2, 2)
    _EDGE_FEAT_SPACE = (5, 6, 2)

    def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', random_split=False,
                 subgraph=False, **dataset_kwargs):
        """
        Load ogbg-molpcba through the OGB package and set up splits, labels, metadata and collation.

        Args:
            - version: Must be None; any other value raises ValueError.
            - root_dir (str): Directory handed to PygGraphPropPredDataset as its root.
            - download (bool): Forwarded unchanged to the parent constructor.
            - split_scheme (str): 'official' is mapped to 'scaffold' before being stored and forwarded.
            - random_split (bool): If True, replace the OGB split by a seeded random permutation
                                   with the same train / validation sizes.
            - subgraph (bool): If True, run GSN preprocessing and serve the preprocessed graphs.
            - dataset_kwargs: Reads 'model' (optional; '3wlgnn' selects the dense collate function) and,
                              when subgraph is True, 'gsn_id_type' and 'gsn_k' (both required).
        Raises:
            - ValueError: If version is not None.
            - KeyError: If subgraph is True and 'gsn_id_type' or 'gsn_k' is missing.
        """
        self._version = version
        if version is not None:
            raise ValueError('Versioning for OGB-MolPCBA is handled through the OGB package. Please set version=none.')
        # internally call ogb package
        self.ogb_dataset = PygGraphPropPredDataset(name='ogbg-molpcba', root=root_dir)

        # set variables
        self._data_dir = self.ogb_dataset.root
        if split_scheme == 'official':
            split_scheme = 'scaffold'
        self._split_scheme = split_scheme
        # although the task is binary classification, the prediction target contains nan value, thus we need float
        self._y_type = 'float'
        self._y_size = self.ogb_dataset.num_tasks
        self._n_classes = self.ogb_dataset.__num_classes__

        self._split_array = self._build_split_array(random_split)

        self._y_array = self.ogb_dataset.data.y
        self._metadata_fields = ['scaffold']
        self._metadata_array = self._load_scaffold_metadata()

        # 'model' is optional: without it the default PyG collater is used instead of raising KeyError.
        self._collate = self._select_collate(dataset_kwargs.get('model'))

        self._metric = Evaluator('ogbg-molpcba')

        # GSN
        self.subgraph = subgraph
        if self.subgraph:
            self._prepare_gsn(dataset_kwargs)

        super().__init__(root_dir, download, split_scheme)

    @staticmethod
    def _parse_version(version):
        """Return the leading numeric components of a dotted version string as a tuple of ints."""
        numbers = []
        for piece in version.split('.'):
            digits = ''
            for char in piece:
                if not char.isdigit():
                    break
                digits += char
            if not digits:
                # Stop at the first component without a leading number (e.g. a 'dev' suffix).
                break
            numbers.append(int(digits))
        return tuple(numbers)

    def _build_split_array(self, random_split):
        """Return a LongTensor with one entry per graph: 0 for train, 1 for validation, 2 for test."""
        split_array = torch.zeros(len(self.ogb_dataset)).long()
        split_idx = self.ogb_dataset.get_idx_split()

        # The global NumPy RNG is seeded whether or not a random split is requested; this is kept
        # unconditional so that code running after construction sees the same RNG state as before.
        np.random.seed(0)
        dataset_size = len(self.ogb_dataset)
        if random_split:
            # Keep the official train / validation sizes; the remainder becomes the test split.
            n_train = len(split_idx['train'])
            n_valid = len(split_idx['valid'])
            random_index = np.random.permutation(dataset_size)
            train_split_idx = random_index[:n_train]
            val_split_idx = random_index[n_train:n_train + n_valid]
            test_split_idx = random_index[n_train + n_valid:]
        else:
            train_split_idx = split_idx['train']
            val_split_idx = split_idx['valid']
            test_split_idx = split_idx['test']

        split_array[train_split_idx] = 0
        split_array[val_split_idx] = 1
        split_array[test_split_idx] = 2
        return split_array

    def _load_scaffold_metadata(self):
        """Load the scaffold group file as an (N, 1) LongTensor, downloading and unzipping it if absent."""
        metadata_file_path = os.path.join(self.ogb_dataset.root, 'raw', 'OGB-MolPCBA_group.npy')
        if not os.path.exists(metadata_file_path):
            metadata_zip_file_path = download_url(
                'https://www.dropbox.com/s/jjpr6tw34llxbpw/OGB-MolPCBA_group.zip?dl=1', self.ogb_dataset.raw_dir)
            extract_zip(metadata_zip_file_path, self.ogb_dataset.raw_dir)
            # The archive is no longer needed once extracted.
            os.unlink(metadata_zip_file_path)
        return torch.from_numpy(np.load(metadata_file_path)).reshape(-1, 1).long()

    def _select_collate(self, model):
        """Return the collate callable: the dense one for '3wlgnn', otherwise a PyG Collater."""
        if model == '3wlgnn':
            return self.collate_dense
        # Compare versions numerically; comparing the raw strings would order '1.10.0' before '1.7.0'.
        if self._parse_version(torch_geometric.__version__) >= (1, 7):
            return PyGCollater(follow_batch=[], exclude_keys=[])
        return PyGCollater(follow_batch=[])

    def _prepare_gsn(self, dataset_kwargs):
        """Run GSN preprocessing and record the node / edge feature dimensions of the resulting graphs."""
        self.id_type = dataset_kwargs['gsn_id_type']
        self.k = dataset_kwargs['gsn_k']
        from gds.datasets.gsn.gsn_data_prep import GSN
        # NOTE(review): dataset_name is 'ogbg-molhiv' although this class wraps ogbg-molpcba; this looks like
        # a copy-paste leftover -- confirm which names GSN supports before changing it.
        gsn = GSN(dataset_name='ogbg-molhiv', dataset_group='ogb', induced=True, id_type=self.id_type,
                  k=self.k)
        self.graphs_ptg, self.encoder_ids, self.d_id, self.d_degree = gsn.preprocess(self.ogb_dataset.root)

        first_graph = self.graphs_ptg[0]
        if first_graph.x.dim() == 1:
            self.num_features = 1
        else:
            self.num_features = first_graph.num_features

        if hasattr(first_graph, 'edge_features'):
            if first_graph.edge_features.dim() == 1:
                self.num_edge_features = 1
            else:
                self.num_edge_features = first_graph.edge_features.shape[1]
        else:
            self.num_edge_features = None

        self.d_in_node_encoder = [self.num_features]
        self.d_in_edge_encoder = [self.num_edge_features]

    def get_input(self, idx):
        """Return the graph at position idx: the GSN-preprocessed one if subgraph is set, else the OGB one."""
        if self.subgraph:
            return self.graphs_ptg[int(idx)]
        return self.ogb_dataset[int(idx)]

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Computes all evaluation metrics.
        Args:
            - y_pred (FloatTensor): Binary logits from a model
            - y_true (LongTensor): Ground-truth labels
            - metadata (Tensor): Metadata
            - prediction_fn (function): A function that turns y_pred into predicted labels.
                                        Only None is supported because OGB Evaluators accept binary logits
        Output:
            - results (dictionary): Dictionary of evaluation metrics
            - results_str (str): String summarizing the evaluation metrics
        Raises:
            - AssertionError: If prediction_fn is not None.
        """
        # Raised explicitly (rather than via `assert`) so the check also runs under `python -O`.
        if prediction_fn is not None:
            raise AssertionError(
                "OGBPCBADataset.eval() does not support prediction_fn. Only binary logits accepted")
        input_dict = {"y_true": y_true, "y_pred": y_pred}
        results = self._metric.eval(input_dict)

        return results, f"Average precision: {results['ap']:.3f}\n"

    @staticmethod
    def _sym_normalize_adj(adjacency):
        """Return D^{-1/2} A D^{-1/2} for a dense square adjacency; zero-degree nodes get zero rows/columns."""
        deg = torch.sum(adjacency, dim=0)
        # zeros_like keeps the fallback value on the same device and dtype as `deg`.
        deg_inv = torch.where(deg > 0, 1. / torch.sqrt(deg), torch.zeros_like(deg))
        deg_inv = torch.diag(deg_inv)
        return torch.mm(deg_inv, torch.mm(adjacency, deg_inv))

    @staticmethod
    def _one_hot_concat(feature, space):
        """One-hot encode every categorical entry of `feature` and concatenate the encodings into one vector."""
        return torch.cat([F.one_hot(label, num_classes) for label, num_classes in zip(feature, space)])

    # prepare dense tensors for GNNs using them; such as RingGNN, 3WLGNN
    def collate_dense(self, samples):
        """
        Collate (graph, label, metadata) triples into dense per-graph feature tensors.

        For each graph a tensor of shape (1 + node_dim + edge_dim, n, n) is built, where n is the number
        of nodes: channel 0 holds the symmetrically normalized adjacency, the next node_dim channels hold
        one-hot node features on the diagonal, and the last edge_dim channels hold one-hot edge features
        at the positions given by edge_index.

        Args:
            - samples (list): Triples (graph, y, metadata)
        Output:
            - feat (Tensor): Stacked per-graph tensors. torch.stack requires equally shaped inputs,
                             so all graphs in the batch must have the same number of nodes.
            - y (Tensor): Stacked labels
            - metadata (Tensor): Stacked metadata
        """
        node_dim = sum(self._NODE_FEAT_SPACE)
        edge_dim = sum(self._EDGE_FEAT_SPACE)

        graph_list, y_list, metadata_list = map(list, zip(*samples))
        # multi-task y, use torch.stack instead of torch.tensor
        y, metadata = torch.stack(y_list), torch.stack(metadata_list)

        feat = []
        for graph in graph_list:
            num_nodes = graph.x.size(0)
            # squeeze(0) drops only the batch dimension; a bare squeeze() would also collapse the
            # (1, 1) adjacency of a single-node graph into a scalar.
            dense_adj = to_dense_adj(graph.edge_index, max_num_nodes=num_nodes).squeeze(0)
            adj = self._sym_normalize_adj(dense_adj)

            # Channel 0 is the normalized adjacency; the remaining channels start out as zeros.
            adj_feat = adj.new_zeros((1 + node_dim + edge_dim, num_nodes, num_nodes))
            adj_feat[0] = adj

            for node, node_feat in enumerate(graph.x):
                adj_feat[1:1 + node_dim, node, node] = self._one_hot_concat(node_feat, self._NODE_FEAT_SPACE)
            for edge in range(graph.edge_index.shape[1]):
                row, col = graph.edge_index[0][edge], graph.edge_index[1][edge]
                edge_feat = graph.edge_attr[edge]
                adj_feat[1 + node_dim:, row, col] = self._one_hot_concat(edge_feat, self._EDGE_FEAT_SPACE)

            feat.append(adj_feat)

        feat = torch.stack(feat)
        return feat, y, metadata