1 import os
            
            2 import numpy as np
            
            3  
            
            4 import pandas as pd
            
            5 import torch
            
            6 import torch_geometric
            
            7 from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
            
            8 from torch_geometric.data.dataloader import Collater as PyGCollater
            
            9  
            
            10 from gds.datasets.gds_dataset import GDSDataset
            
            11 from torch_geometric.utils import to_dense_adj
            
            12  
            
            13  
            
            14 def add_zeros(data):
            
            15     data.x = torch.zeros(data.num_nodes, dtype=torch.long)
            
            16     return data
            
            17  
            
            18  
            
            19 class OGBGPPADataset(GDSDataset):
            
            20     """
            
            21     The OGB-ppa dataset.
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (102 > 79 characters)
22     This dataset is directly adopted from Open Graph Benchmark, and originally curated by MoleculeNet. 
            23  
            
            24     Supported `split_scheme`:
            
            25         - 'official' or 'scaffold', which are equivalent
            
            26  
            
            27     Input (x):
            
            28         Molecular graphs represented as Pytorch Geometric data objects
            
            29  
            
            30     Label (y):
            
            31         y represents 128-class binary labels.
            
            32  
            
            33     Metadata:
            
            34         - scaffold
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (93 > 79 characters)
35             Each molecule is annotated with the scaffold ID that the molecule is assigned to. 
            36  
            
            37     Website:
            
            38         https://ogb.stanford.edu/docs/graphprop/#ogbg-mol
            
            39  
            
            40     Original publication:
            
            41         @article{hu2020ogb,
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (82 > 79 characters)
42             title={Open Graph Benchmark: Datasets for Machine Learning on Graphs}, 
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (112 > 79 characters)
43             author={W. {Hu}, M. {Fey}, M. {Zitnik}, Y. {Dong}, H. {Ren}, B. {Liu}, M. {Catasta}, J. {Leskovec}}, 
            44             journal={arXiv preprint arXiv:2005.00687},
            
            45             year={2020}
            
            46         }
            
            47  
            
            48         @article{wu2018moleculenet,
            
            49             title={MoleculeNet: a benchmark for molecular machine learning},
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (129 > 79 characters)
50             author={Z. {Wu}, B. {Ramsundar}, E. V {Feinberg}, J. {Gomes}, C. {Geniesse}, A. S {Pappu}, K. {Leswing}, V. {Pande}}, 
            51             journal={Chemical science},
            
            52             volume={9},
            
            53             number={2},
            
            54             pages={513--530},
            
            55             year={2018},
            
            56             publisher={Royal Society of Chemistry}
            
            57         }
            
            58  
            
            59     License:
            
            60         This dataset is distributed under the MIT license.
            
            61         https://github.com/snap-stanford/ogb/blob/master/LICENSE
            
            62     """
            
            63  
            
            64     _dataset_name = 'ogbg-ppa'
            
            65     _versions_dict = {
            
            66         '1.0': {
            
            67             'download_url': None,
            
            68             'compressed_size': None}}
            
            69  
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (114 > 79 characters)
70     def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', random_split=False, 
            71                  subgraph=False, **dataset_kwargs):
            
            72         self._version = version
            
            73         if version is not None:
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (115 > 79 characters)
74             raise ValueError('Versioning for OGB-PPA is handled through the OGB package. Please set version=none.') 
            75         # internally call ogb package
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (103 > 79 characters)
76         self.ogb_dataset = PygGraphPropPredDataset(name='ogbg-ppa', root=root_dir, transform=add_zeros) 
            77  
            
            78         # set variables
            
            79         self._data_dir = self.ogb_dataset.root
            
            80         if split_scheme == 'official':
            
            81             split_scheme = 'scaffold'
            
            82         self._split_scheme = split_scheme
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (138 > 79 characters)
83         self._y_type = 'float'  # although the task is binary classification, the prediction target contains nan value, thus we need float 
            84         self._y_size = self.ogb_dataset.num_tasks
            
            85         self._n_classes = self.ogb_dataset.__num_classes__
            
            86  
            
            87         self._split_array = torch.zeros(len(self.ogb_dataset)).long()
            
            88         split_idx = self.ogb_dataset.get_idx_split()
            
            89  
            
            90         np.random.seed(0)
            
            91         dataset_size = len(self.ogb_dataset)
            
            92         if random_split:
            
            93             random_index = np.random.permutation(dataset_size)
            
            94             train_split_idx = random_index[:len(split_idx['train'])]
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (113 > 79 characters)
95             val_split_idx = random_index[len(split_idx['train']):len(split_idx['train'])+len(split_idx['valid'])] 
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (91 > 79 characters)
96             test_split_idx = random_index[len(split_idx['train'])+len(split_idx['valid']):] 
            
               
               
                  - 
                     
                        E203
                     
                     Whitespace before ':'
97         else : 
            98             train_split_idx = split_idx['train']
            
            99             val_split_idx = split_idx['valid']
            
            100             test_split_idx = split_idx['test']
            
            101  
            
            102         self._split_array[train_split_idx] = 0
            
            103         self._split_array[val_split_idx] = 1
            
            104         self._split_array[test_split_idx] = 2
            
            105  
            
            106  
            
            
               
               
                  - 
                     
                        E303
                     
                     Too many blank lines (2)
107         self._y_array = self.ogb_dataset.data.y.squeeze() 
            108  
            
            109         self._metadata_fields = ['species', 'y']
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (104 > 79 characters)
110         metadata_file_path = os.path.join(self.ogb_dataset.root, 'mapping', 'graphidx2speciesid.csv.gz') 
            111         df = pd.read_csv(metadata_file_path)
            
            112         df_np = df['species id'].to_numpy()
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (95 > 79 characters)
113         self._metadata_array_wo_y = torch.from_numpy(df_np - df_np.min()).reshape(-1, 1).long() 
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (97 > 79 characters)
114         self._metadata_array = torch.cat((self._metadata_array_wo_y, self.ogb_dataset.data.y), 1) 
            115  
            
            116         if dataset_kwargs['model'] == '3wlgnn':
            
            117             self._collate = self.collate_dense
            
            118         else:
            
            119             if torch_geometric.__version__ >= '1.7.0':
            
            120                 self._collate = PyGCollater(follow_batch=[], exclude_keys=[])
            
            121             else:
            
            122                 self._collate = PyGCollater(follow_batch=[])
            
            123  
            
            124         self._metric = Evaluator('ogbg-ppa')
            
            125  
            
            126  
            
            
               
               
                  - 
                     
                        E303
                     
                     Too many blank lines (2)
127         # GSN 
            128         self.subgraph = subgraph
            
            129         if self.subgraph:
            
            130             self.id_type = dataset_kwargs['gsn_id_type']
            
            131             self.k = dataset_kwargs['gsn_k']
            
            132             from gds.datasets.gsn.gsn_data_prep import GSN
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (108 > 79 characters)
133             subgraph = GSN(dataset_name='ogbg-ppa', dataset_group='ogb', induced=True, id_type=self.id_type, 
            134                            k=self.k)
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (116 > 79 characters)
135             self.graphs_ptg, self.encoder_ids, self.d_id, self.d_degree = subgraph.preprocess(self.ogb_dataset.root) 
            136  
            
            137             if self.graphs_ptg[0].x.dim() == 1:
            
            138                 self.num_features = 1
            
            139             else:
            
            140                 self.num_features = self.graphs_ptg[0].num_features
            
            141  
            
            142             if hasattr(self.graphs_ptg[0], 'edge_features'):
            
            143                 if self.graphs_ptg[0].edge_features.dim() == 1:
            
            144                     self.num_edge_features = 1
            
            145                 else:
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (86 > 79 characters)
146                     self.num_edge_features = self.graphs_ptg[0].edge_features.shape[1] 
            147             else:
            
            148                 self.num_edge_features = None
            
            149  
            
            150             self.d_in_node_encoder = [self.num_features]
            
            151             self.d_in_edge_encoder = [self.num_edge_features]
            
            152  
            
            153         super().__init__(root_dir, download, split_scheme)
            
            154  
            
            155     def get_input(self, idx):
            
            156         if self.subgraph:
            
            157             return self.graphs_ptg[int(idx)]
            
            158         else:
            
            159             return self.ogb_dataset[int(idx)]
            
            160  
            
            161     def eval(self, y_pred, y_true, metadata, prediction_fn=None):
            
            162         """
            
            163         Computes all evaluation metrics.
            
            164         Args:
            
            165             - y_pred (FloatTensor): Binary logits from a model
            
            166             - y_true (LongTensor): Ground-truth labels
            
            167             - metadata (Tensor): Metadata
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (91 > 79 characters)
168             - prediction_fn (function): A function that turns y_pred into predicted labels. 
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (106 > 79 characters)
169                                         Only None is supported because OGB Evaluators accept binary logits 
            170         Output:
            
            171             - results (dictionary): Dictionary of evaluation metrics
            
            172             - results_str (str): String summarizing the evaluation metrics
            
            173         """
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (120 > 79 characters)
174         assert prediction_fn is None, "OGBHIVDataset.eval() does not support prediction_fn. Only binary logits accepted" 
            175         y_true = y_true.view(-1, 1)
            
            176         y_pred = torch.argmax(y_pred.detach(), dim=1).view(-1, 1)
            
            177         input_dict = {"y_true": y_true, "y_pred": y_pred}
            
            178  
            
            179         results = self._metric.eval(input_dict)
            
            180  
            
            181         return results, f"Accuracy: {results['acc']:.3f}\n"
            
            182  
            
            183     # prepare dense tensors for GNNs using them; such as RingGNN, 3WLGNN
            
            184     def collate_dense(self, samples):
            
            185         def _sym_normalize_adj(adjacency):
            
            186             deg = torch.sum(adjacency, dim=0)  # .squeeze()
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (89 > 79 characters)
187             deg_inv = torch.where(deg > 0, 1. / torch.sqrt(deg), torch.zeros(deg.size())) 
            188             deg_inv = torch.diag(deg_inv)
            
            189             return torch.mm(deg_inv, torch.mm(adjacency, deg_inv))
            
            190  
            
            191         # The input samples is a list of pairs (graph, label).
            
            192         graph_list, y_list, metadata_list = map(list, zip(*samples))
            
            193         y, metadata = torch.tensor(y_list), torch.stack(metadata_list)
            
            194  
            
            195         x_edge_feat = []
            
            196         for graph in graph_list:
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (109 > 79 characters)
197             adj = _sym_normalize_adj(to_dense_adj(graph.edge_index, max_num_nodes=graph.x.size(0)).squeeze()) 
            198             zero_adj = torch.zeros_like(adj)
            
            199             in_dim = graph.edge_attr.shape[1]
            
            200  
            
            201             # use node feats to prepare adj
            
            202             adj_edge_feat = torch.stack([zero_adj for _ in range(in_dim)])
            
            203             adj_edge_feat = torch.cat([adj.unsqueeze(0), adj_edge_feat], dim=0)
            
            204  
            
            205             for edge in range(graph.edge_index.shape[1]):
            
            
               
               
                  - 
                     
                        E501
                     
                     Line too long (85 > 79 characters)
206                 target, source = graph.edge_index[0][edge], graph.edge_index[1][edge] 
            207                 adj_edge_feat[1:, target, source] = graph.edge_attr[edge]
            
            208  
            
            209             x_edge_feat.append(adj_edge_feat)
            
            210  
            
            211         x_edge_feat = torch.stack(x_edge_feat)
            
            212         return x_edge_feat, y, metadata