# datasets/sbm_environment_dataset.py

import os
import torch
import numpy as np
from gds.datasets.gds_dataset import GDSDataset
from ogb.graphproppred import Evaluator
from torch_geometric.data import download_url, extract_zip
from torch_geometric.data.dataloader import Collater as PyGCollater
import torch_geometric
from .pyg_sbm_environment_dataset import PyGSBMEnvironmentDataset
from torch_geometric.utils import to_dense_adj
import torch.nn.functional as F


class SBMEnvironmentDataset(GDSDataset):
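    """SBM-Environment: synthetic stochastic-block-model graphs with
    per-graph environment ("composition") groups that define the
    distribution-shift splits. Wraps PyGSBMEnvironmentDataset in the
    GDS dataset interface.
    """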
    _dataset_name = 'SBM-Environment'
    _versions_dict = {
        '1.0': {
            'download_url': None,
            'compressed_size': None}}

    def __init__(self, version=None, root_dir='data', download=False,
                 split_scheme='official', random_split=False,
                 **dataset_kwargs):
        self._version = version
        # internally build the ogb-style PyG dataset
        self.ogb_dataset = PyGSBMEnvironmentDataset(name='SBM-Environment',
                                                    root=root_dir)

        # set variables
        self._data_dir = self.ogb_dataset.root
        if split_scheme == 'official':
            split_scheme = 'composition'
        self._split_scheme = split_scheme
        # Although the task is binary classification, the prediction
        # target contains NaN values, so the labels are stored as floats.
        self._y_type = 'float'
        self._y_size = self.ogb_dataset.num_tasks
        self._n_classes = self.ogb_dataset.__num_classes__

        self._split_array = torch.zeros(len(self.ogb_dataset)).long()

        self._y_array = self.ogb_dataset.data.y
        self._metadata_fields = ['composition', 'y']
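
        # Per-graph environment (composition) labels ship as a cached
        # .npy file; download and unpack it on first use.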
        metadata_file_path = os.path.join(
            self.ogb_dataset.raw_dir, 'SBM-Environment_group.npy')
        if not os.path.exists(metadata_file_path):
            metadata_zip_file_path = download_url(
                'https://www.dropbox.com/s/5xjd13f2qfiyku5/SBM-Environment_group.zip?dl=1',
                self.ogb_dataset.raw_dir)
            extract_zip(metadata_zip_file_path, self.ogb_dataset.raw_dir)
            os.unlink(metadata_zip_file_path)
        self._metadata_array_wo_y = torch.from_numpy(
            np.load(metadata_file_path)).reshape(-1, 1).long()
        self._metadata_array = torch.cat(
            (self._metadata_array_wo_y,
             torch.unsqueeze(self.ogb_dataset.data.y, dim=1)), 1)
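
        # Split construction: with random_split=True, graphs are split
        # 50/25/25 at random; otherwise composition groups 0-2 form the
        # train/val pool (shuffled 2:1) and group 3 is held out for test.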
        np.random.seed(0)
        dataset_size = len(self.ogb_dataset)
        if random_split:
            random_index = np.random.permutation(dataset_size)
            train_split_idx = random_index[:int(2 / 4 * dataset_size)]
            val_split_idx = random_index[int(2 / 4 * dataset_size):
                                         int(3 / 4 * dataset_size)]
            test_split_idx = random_index[int(3 / 4 * dataset_size):]
        else:
            # use the group info to split the data
            train_val_group_idx, test_group_idx = range(0, 3), range(3, 4)
            train_val_group_idx = torch.tensor(train_val_group_idx)
            test_group_idx = torch.tensor(test_group_idx)

            def split_idx(group_idx):
                # Boolean mask over graphs whose group id is in group_idx.
                split_idx = torch.zeros(
                    len(torch.squeeze(self._metadata_array_wo_y)),
                    dtype=torch.bool)
                for idx in group_idx:
                    split_idx += (
                        torch.squeeze(self._metadata_array_wo_y) == idx)
                return split_idx

            train_val_split_idx = split_idx(train_val_group_idx)
            test_split_idx = split_idx(test_group_idx)

            # Convert the boolean mask to integer indices, then shuffle
            # and carve off the first 2/3 for training.
            train_val_split_idx = torch.arange(
                dataset_size, dtype=torch.int64)[train_val_split_idx]
            train_val_sets_size = len(train_val_split_idx)
            random_index = np.random.permutation(train_val_sets_size)
            train_split_idx = train_val_split_idx[
                random_index[:int(2 / 3 * train_val_sets_size)]]
            val_split_idx = train_val_split_idx[
                random_index[int(2 / 3 * train_val_sets_size):]]

        self._split_array[train_split_idx] = 0
        self._split_array[val_split_idx] = 1
        self._split_array[test_split_idx] = 2
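
        # 3WLGNN consumes dense tensors, so it needs the custom collater
        # below; every other model uses the standard PyG collater (whose
        # exclude_keys argument first appeared in torch_geometric 1.7.0).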
        if dataset_kwargs['model'] == '3wlgnn':
            self._collate = self.collate_dense
        else:
            if torch_geometric.__version__ >= '1.7.0':
                self._collate = PyGCollater(follow_batch=[], exclude_keys=[])
            else:
                self._collate = PyGCollater(follow_batch=[])

        # Reuse the accuracy evaluator from ogbg-ppa.
        self._metric = Evaluator('ogbg-ppa')

        super().__init__(root_dir, download, split_scheme)
    def get_input(self, idx):
        return self.ogb_dataset[int(idx)]

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Computes all evaluation metrics.
        Args:
            - y_pred (FloatTensor): Binary logits from a model
            - y_true (LongTensor): Ground-truth labels
            - metadata (Tensor): Metadata
            - prediction_fn (function): A function that turns y_pred into
              predicted labels. Only None is supported because OGB
              Evaluators accept binary logits.
        Output:
            - results (dictionary): Dictionary of evaluation metrics
            - results_str (str): String summarizing the evaluation metrics
        """
        assert prediction_fn is None, \
            "SBMEnvironmentDataset.eval() does not support prediction_fn. " \
            "Only binary logits accepted"
        y_true = y_true.view(-1, 1)
        y_pred = torch.argmax(y_pred.detach(), dim=1).view(-1, 1)
        input_dict = {"y_true": y_true, "y_pred": y_pred}
        results = self._metric.eval(input_dict)

        return results, f"Accuracy: {results['acc']:.3f}\n"
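
    # A usage sketch (hypothetical names): for logits of shape
    # (N, num_classes) and integer labels of shape (N,),
    #     results, results_str = dataset.eval(logits, labels, metadata)
    # yields a dict like {'acc': ...} and the string "Accuracy: 0.xxx".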

    # Prepare dense tensors for GNNs that consume them, such as RingGNN
    # and 3WLGNN.
    def collate_dense(self, samples):
        def _sym_normalize_adj(adjacency):
            # Symmetric normalization D^{-1/2} A D^{-1/2}; isolated nodes
            # (zero degree) keep a zero row/column.
            deg = torch.sum(adjacency, dim=0)  # .squeeze()
            deg_inv = torch.where(deg > 0, 1. / torch.sqrt(deg),
                                  torch.zeros(deg.size()))
            deg_inv = torch.diag(deg_inv)
            return torch.mm(deg_inv, torch.mm(adjacency, deg_inv))

        # The input `samples` is a list of triples (graph, label, metadata).
        node_feat_space = torch.tensor([8])

        graph_list, y_list, metadata_list = map(list, zip(*samples))
        y, metadata = torch.tensor(y_list), torch.stack(metadata_list)
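
        # Each graph becomes a (1 + feature_dim) x N x N tensor: channel 0
        # holds the normalized adjacency, and channels 1..feature_dim carry
        # the one-hot node features on the diagonal.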
        feat = []
        for graph in graph_list:
            adj = _sym_normalize_adj(
                to_dense_adj(graph.edge_index,
                             max_num_nodes=graph.x.size(0)).squeeze())
            zero_adj = torch.zeros_like(adj)
            in_dim = node_feat_space.sum()

            # use node feats to prepare adj
            adj_feat = torch.stack([zero_adj for _ in range(in_dim)])
            adj_feat = torch.cat([adj.unsqueeze(0), adj_feat], dim=0)

            def convert(feature, space):
                # One-hot encode each categorical entry and concatenate.
                out = []
                for i, label in enumerate(feature):
                    out.append(F.one_hot(label, space[i]))
                return torch.cat(out)

            for node, node_feat in enumerate(graph.x):
                adj_feat[1:1 + node_feat_space.sum(), node, node] = convert(
                    node_feat.unsqueeze(0), node_feat_space)

            feat.append(adj_feat)

        feat = torch.stack(feat)
        return feat, y, metadata
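

# A minimal usage sketch (assumptions: the processed data live under
# `data/`, and 'gin' is a valid model key in the surrounding codebase):
#
#     dataset = SBMEnvironmentDataset(root_dir='data', model='gin')
#     graph = dataset.get_input(0)       # a torch_geometric Data object
#     splits = dataset._split_array      # 0 = train, 1 = val, 2 = test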