⬅ datasets/colored_mnist_dataset.py source

1 import os
2 import torch
3 import numpy as np
4 from gds.datasets.gds_dataset import GDSDataset
5 from ogb.graphproppred import Evaluator
6 from torch_geometric.data import download_url, extract_zip
7 from torch_geometric.data.dataloader import Collater as PyGCollater
8 import torch_geometric
9 from .pyg_colored_mnist_dataset import PyGColoredMNISTDataset
10 from torch_geometric.utils import to_dense_adj
  • F401 'torch.nn.functional as F' imported but unused
11 import torch.nn.functional as F
12  
  • E302 Expected 2 blank lines, found 1
class ColoredMNISTDataset(GDSDataset):
    """ColoredMNIST graph dataset wrapped in the GDS dataset interface.

    Wraps ``PyGColoredMNISTDataset`` (one PyG graph per digit) and attaches
    the GDS bookkeeping: a train/val/test split array, a metadata array with
    fields ``['color', 'y']`` (the per-graph color group plus the label), and
    an OGB evaluator (``ogbg-molhiv``, i.e. ROC-AUC) for the binary
    classification task.

    Split schemes:
        - 'official': resolved to 'color' (split by color group).
    """

    _dataset_name = 'ColoredMNIST'
    _versions_dict = {
        '1.0': {
            'download_url': None,
            'compressed_size': None}
    }

    def __init__(self, version=None, root_dir='data', download=False,
                 split_scheme='official', random_split=False,
                 **dataset_kwargs):
        """
        Args:
            version (str): Dataset version (recorded only; no download URL).
            root_dir (str): Directory under which the PyG dataset lives.
            download (bool): Forwarded to the GDSDataset base class.
            split_scheme (str): 'official' is an alias for 'color'.
            random_split (bool): If True, ignore the group metadata and split
                uniformly at random into 3/6 train, 1/6 val, 2/6 test.
            dataset_kwargs: Must contain 'model'; '3wlgnn' selects the dense
                collate function, anything else uses the standard PyG collater.
        """
        self._version = version
        # Internally delegate storage/processing to the PyG dataset.
        self.ogb_dataset = PyGColoredMNISTDataset(name='ColoredMNIST',
                                                  root=root_dir)

        # Basic dataset attributes expected by GDSDataset.
        self._data_dir = self.ogb_dataset.root
        if split_scheme == 'official':
            split_scheme = 'color'
        self._split_scheme = split_scheme
        # Although the task is binary classification, the prediction target
        # contains nan values, so the label type must be float.
        self._y_type = 'float'
        self._y_size = self.ogb_dataset.num_tasks
        self._n_classes = self.ogb_dataset.__num_classes__

        self._split_array = torch.zeros(len(self.ogb_dataset)).long()

        self._y_array = self.ogb_dataset.data.y.unsqueeze(-1)
        self._metadata_fields = ['color', 'y']

        # Fetch the per-graph color-group assignments if not already cached.
        metadata_file_path = os.path.join(self.ogb_dataset.raw_dir,
                                          'ColoredMNIST_group.npy')
        if not os.path.exists(metadata_file_path):
            metadata_zip_file_path = download_url(
                'https://www.dropbox.com/s/ax1ek3yc6n1q739/ColoredMNIST_group.zip?dl=1',
                self.ogb_dataset.raw_dir)
            extract_zip(metadata_zip_file_path, self.ogb_dataset.raw_dir)
            os.unlink(metadata_zip_file_path)

        # Column 0: color group id; column 1: label.
        self._metadata_array_wo_y = torch.from_numpy(
            np.load(metadata_file_path)).reshape(-1, 1).long()
        self._metadata_array = torch.cat(
            (self._metadata_array_wo_y,
             torch.unsqueeze(self.ogb_dataset.data.y, dim=1)), 1)

        # Fixed seed so the split is reproducible across runs.
        np.random.seed(0)
        dataset_size = len(self.ogb_dataset)
        if random_split:
            random_index = np.random.permutation(dataset_size)
            train_split_idx = random_index[:int(3 / 6 * dataset_size)]
            val_split_idx = random_index[int(3 / 6 * dataset_size):
                                         int(4 / 6 * dataset_size)]
            test_split_idx = random_index[int(4 / 6 * dataset_size):]
        else:
            # Use the group info to split: groups {0, 1} are train/val,
            # group {2} is test.
            train_val_group_idx = torch.tensor(range(0, 2))
            test_group_idx = torch.tensor(range(2, 3))

            def split_idx(group_idx):
                """Boolean mask of examples whose group is in group_idx."""
                groups = torch.squeeze(self._metadata_array_wo_y)
                mask = torch.zeros(len(groups), dtype=torch.bool)
                for idx in group_idx:
                    # Logical OR; bool-tensor `+=` is deprecated/removed in
                    # newer torch releases.
                    mask |= (groups == idx)
                return mask

            train_val_split_idx = split_idx(train_val_group_idx)
            test_split_idx = split_idx(test_group_idx)

            # Shuffle the train/val pool and carve off 3/4 train, 1/4 val.
            train_val_split_idx = torch.arange(
                dataset_size, dtype=torch.int64)[train_val_split_idx]
            train_val_sets_size = len(train_val_split_idx)
            random_index = np.random.permutation(train_val_sets_size)
            train_split_idx = train_val_split_idx[
                random_index[:int(3 / 4 * train_val_sets_size)]]
            val_split_idx = train_val_split_idx[
                random_index[int(3 / 4 * train_val_sets_size):]]

        self._split_array[train_split_idx] = 0
        self._split_array[val_split_idx] = 1
        self._split_array[test_split_idx] = 2

        if dataset_kwargs['model'] == '3wlgnn':
            self._collate = self.collate_dense
        else:
            # Compare parsed (major, minor) numbers; a lexicographic string
            # comparison would misorder e.g. '1.10.0' < '1.7.0'.
            pyg_version = tuple(
                int(part) for part in
                torch_geometric.__version__.split('.')[:2])
            if pyg_version >= (1, 7):
                self._collate = PyGCollater(follow_batch=[], exclude_keys=[])
            else:
                self._collate = PyGCollater(follow_batch=[])

        self._metric = Evaluator('ogbg-molhiv')
        super().__init__(root_dir, download, split_scheme)

    def get_input(self, idx):
        """Return the PyG graph at integer index ``idx``."""
        return self.ogb_dataset[int(idx)]

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Computes all evaluation metrics.
        Args:
            - y_pred (FloatTensor): Binary logits from a model
            - y_true (LongTensor): Ground-truth labels
            - metadata (Tensor): Metadata
            - prediction_fn (function): A function that turns y_pred into
              predicted labels. Only None is supported because OGB Evaluators
              accept binary logits
        Output:
            - results (dictionary): Dictionary of evaluation metrics
            - results_str (str): String summarizing the evaluation metrics
        """
        assert prediction_fn is None
        input_dict = {"y_true": y_true, "y_pred": y_pred}
        results = self._metric.eval(input_dict)

        return results, f"ROCAUC: {results['rocauc']:.3f}\n"

    # prepare dense tensors for GNNs using them; such as RingGNN, 3WLGNN
    def collate_dense(self, samples):
        """Collate (graph, y, metadata) triples into dense batch tensors.

        Each graph is turned into a (1 + in_dim, N, N) tensor: channel 0 is
        the symmetrically normalized dense adjacency; channels 1.. carry the
        node features on the diagonal.
        """
        graph_list, y_list, metadata_list = map(list, zip(*samples))
        y, metadata = torch.tensor(y_list), torch.stack(metadata_list)

        # Insert size one at dim 0 because this dataset's y is 1d.
        y = y.unsqueeze(0)

        x_node_feat = []
        for graph in graph_list:
            adj = self._sym_normalize_adj(
                to_dense_adj(graph.edge_index).squeeze())
            zero_adj = torch.zeros_like(adj)
            in_dim = graph.x.shape[1]

            # Use node feats to prepare the stacked adjacency tensor:
            # one zero channel per feature dimension, after the adjacency.
            adj_node_feat = torch.stack([zero_adj for _ in range(in_dim)])
            adj_node_feat = torch.cat([adj.unsqueeze(0), adj_node_feat],
                                      dim=0)

            # Place each node's feature vector on that node's diagonal entry.
            for node, node_feat in enumerate(graph.x):
                adj_node_feat[1:, node, node] = node_feat

            x_node_feat.append(adj_node_feat)
        x_node_feat = torch.stack(x_node_feat)

        return x_node_feat, y, metadata

    def _sym_normalize_adj(self, adj):
        """Symmetrically normalize a dense adjacency: D^-1/2 * A * D^-1/2."""
        deg = torch.sum(adj, dim=0)
        # Guard isolated nodes (degree 0) against division by zero.
        deg_inv = torch.where(deg > 0, 1. / torch.sqrt(deg),
                              torch.zeros(deg.size()))
        deg_inv = torch.diag(deg_inv)
        return torch.mm(deg_inv, torch.mm(adj, deg_inv))