datasets/rotated_mnist_dataset.py source

import os
import torch
import numpy as np
from gds.datasets.gds_dataset import GDSDataset
from ogb.graphproppred import Evaluator
from torch_geometric.data import download_url, extract_zip
from torch_geometric.data.dataloader import Collater as PyGCollater
import torch_geometric
from .pyg_rotated_mnist_dataset import PyGRotatedMNISTDataset
from torch_geometric.utils import to_dense_adj


class RotatedMNISTDataset(GDSDataset):
    _dataset_name = 'RotatedMNIST'
    _versions_dict = {
        '1.0': {
            'download_url': None,
            'compressed_size': None}}

    def __init__(self, version=None, root_dir='data', download=False,
                 split_scheme='official', random_split=False,
                 subgraph=False, **dataset_kwargs):
        self._version = version
        # internally call ogb package
        self.ogb_dataset = PyGRotatedMNISTDataset(name='RotatedMNIST',
                                                  root=root_dir)

        # set variables
        self._data_dir = self.ogb_dataset.root
        if split_scheme == 'official':
            split_scheme = 'angle'
        self._split_scheme = split_scheme
        # although the task is classification, the prediction target
        # contains nan values, thus we need float
        self._y_type = 'float'
        self._y_size = self.ogb_dataset.num_tasks
        self._n_classes = self.ogb_dataset.__num_classes__

        self._split_array = torch.zeros(len(self.ogb_dataset)).long()

        self._y_array = self.ogb_dataset.data.y
        self._metadata_fields = ['angle', 'y']

        # https://www.dropbox.com/s/zulrcyh846w9maw/RotatedMNIST_group_expired.zip?dl=1
        metadata_file_path = os.path.join(self.ogb_dataset.raw_dir,
                                          'RotatedMNIST_group.npy')
        if not os.path.exists(metadata_file_path):
            metadata_zip_file_path = download_url(
                'https://www.dropbox.com/s/xv6bx7ihqmeuv80/RotatedMNIST_group.zip?dl=1',
                self.ogb_dataset.raw_dir)
            extract_zip(metadata_zip_file_path, self.ogb_dataset.raw_dir)
            os.unlink(metadata_zip_file_path)

        self._metadata_array_wo_y = torch.from_numpy(
            np.load(metadata_file_path)).reshape(-1, 1).long()
        self._metadata_array = torch.cat(
            (self._metadata_array_wo_y,
             torch.unsqueeze(self.ogb_dataset.data.y, dim=1)), 1)

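        # The metadata array has two columns per graph: the integer angle
        # group id loaded from RotatedMNIST_group.npy, and the class label y.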
        np.random.seed(0)
        dataset_size = len(self.ogb_dataset)
        if random_split:
            random_index = np.random.permutation(dataset_size)
            train_split_idx = random_index[:int(4 / 6 * dataset_size)]
            val_split_idx = random_index[int(4 / 6 * dataset_size):
                                         int(5 / 6 * dataset_size)]
            test_split_idx = random_index[int(5 / 6 * dataset_size):]
        else:
            # use the group info to split the data
            train_val_group_idx, test_group_idx = range(0, 5), range(5, 6)
            train_val_group_idx = torch.tensor(train_val_group_idx)
            test_group_idx = torch.tensor(test_group_idx)

            def split_idx(group_idx):
                split_idx = torch.zeros(
                    len(torch.squeeze(self._metadata_array_wo_y)),
                    dtype=torch.bool)
                for idx in group_idx:
                    split_idx += (torch.squeeze(self._metadata_array_wo_y) == idx)
                return split_idx

            train_val_split_idx = split_idx(train_val_group_idx)
            test_split_idx = split_idx(test_group_idx)

            train_val_split_idx = torch.arange(
                dataset_size, dtype=torch.int64)[train_val_split_idx]
            train_val_sets_size = len(train_val_split_idx)
            random_index = np.random.permutation(train_val_sets_size)
            train_split_idx = train_val_split_idx[
                random_index[:int(4 / 5 * train_val_sets_size)]]
            val_split_idx = train_val_split_idx[
                random_index[int(4 / 5 * train_val_sets_size):]]

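        # With random_split, the data are shuffled into train/val/test with
        # ratio 4/6 : 1/6 : 1/6; otherwise angle groups 0-4 are split 4:1
        # into train/val and group 5 is held out as the test domain.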
        self._split_array[train_split_idx] = 0
        self._split_array[val_split_idx] = 1
        self._split_array[test_split_idx] = 2

        if dataset_kwargs['model'] == '3wlgnn':
            self._collate = self.collate_dense
        else:
            if torch_geometric.__version__ >= '1.7.0':
                self._collate = PyGCollater(follow_batch=[], exclude_keys=[])
            else:
                self._collate = PyGCollater(follow_batch=[])

        self._metric = Evaluator('ogbg-ppa')

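        # When subgraph=True, the GSN (Graph Substructure Networks)
        # preprocessing below attaches substructure-count identifiers to each
        # graph; get_input() then serves these annotated graphs instead of
        # the raw PyG graphs.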
        # GSN
        self.subgraph = subgraph
        if self.subgraph:
            self.id_type = dataset_kwargs['gsn_id_type']
            self.k = dataset_kwargs['gsn_k']
            from gds.datasets.gsn.gsn_data_prep import GSN
            subgraph = GSN(dataset_name='RotatedMNIST', dataset_group='MNIST',
                           induced=True, id_type=self.id_type, k=self.k)
            self.graphs_ptg, self.encoder_ids, self.d_id, self.d_degree = \
                subgraph.preprocess(self.ogb_dataset.root)

            if self.graphs_ptg[0].x.dim() == 1:
                self.num_features = 1
            else:
                self.num_features = self.graphs_ptg[0].num_features

            if hasattr(self.graphs_ptg[0], 'edge_features'):
                if self.graphs_ptg[0].edge_features.dim() == 1:
                    self.num_edge_features = 1
                else:
                    self.num_edge_features = self.graphs_ptg[0].edge_features.shape[1]
            else:
                self.num_edge_features = None

            self.d_in_node_encoder = [self.num_features]
            self.d_in_edge_encoder = [self.num_edge_features]

        super().__init__(root_dir, download, split_scheme)

    def get_input(self, idx):
        if self.subgraph:
            return self.graphs_ptg[int(idx)]
        else:
            return self.ogb_dataset[int(idx)]

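    # For a batch of N graphs, eval() expects y_pred of shape (N, n_classes)
    # and y_true of shape (N,) or (N, 1), and reports multi-class accuracy,
    # e.g. results, results_str = dataset.eval(logits, labels, metadata)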
    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Computes all evaluation metrics.
        Args:
            - y_pred (FloatTensor): Class logits from a model
            - y_true (LongTensor): Ground-truth labels
            - metadata (Tensor): Metadata
            - prediction_fn (function): A function that turns y_pred into
              predicted labels. Only None is supported; predictions are
              obtained by taking the argmax of the logits.
        Output:
            - results (dictionary): Dictionary of evaluation metrics
            - results_str (str): String summarizing the evaluation metrics
        """
        assert prediction_fn is None, \
            "RotatedMNISTDataset.eval() does not support prediction_fn. " \
            "Only logits are accepted."
        y_true = y_true.view(-1, 1)
        y_pred = torch.argmax(y_pred.detach(), dim=1).view(-1, 1)
        input_dict = {"y_true": y_true, "y_pred": y_pred}
        results = self._metric.eval(input_dict)

        return results, f"Accuracy: {results['acc']:.3f}\n"

    # prepare dense tensors for GNNs that use them, such as RingGNN and 3WLGNN
    def collate_dense(self, samples):
        def _sym_normalize_adj(adjacency):
            # symmetric normalization: D^{-1/2} A D^{-1/2}
            deg = torch.sum(adjacency, dim=0)  # .squeeze()
            deg_inv = torch.where(deg > 0, 1. / torch.sqrt(deg),
                                  torch.zeros(deg.size()))
            deg_inv = torch.diag(deg_inv)
            return torch.mm(deg_inv, torch.mm(adjacency, deg_inv))

        # The input samples is a list of triples (graph, label, metadata).
        graph_list, y_list, metadata_list = map(list, zip(*samples))
        y, metadata = torch.tensor(y_list), torch.stack(metadata_list)

        x_node_feat = []
        for graph in graph_list:
            adj = _sym_normalize_adj(
                to_dense_adj(graph.edge_index,
                             max_num_nodes=graph.x.size(0)).squeeze())
            zero_adj = torch.zeros_like(adj)
            in_dim = graph.x.shape[1]

            # use node feats to prepare adj
            adj_node_feat = torch.stack([zero_adj for _ in range(in_dim)])
            adj_node_feat = torch.cat([adj.unsqueeze(0), adj_node_feat], dim=0)
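            # adj_node_feat has shape (in_dim + 1, N, N): channel 0 holds the
            # normalized adjacency and channels 1..in_dim carry the node
            # features, written onto the diagonal by the loop below.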

            for node, node_feat in enumerate(graph.x):
                adj_node_feat[1:, node, node] = node_feat

            x_node_feat.append(adj_node_feat)

        x_node_feat = torch.stack(x_node_feat)
        return x_node_feat, y, metadata
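

# Usage sketch: 'gin' is an arbitrary placeholder model name; any value other
# than '3wlgnn' selects the standard PyG collater, and with subgraph left at
# its default the GSN preprocessing is skipped.
if __name__ == '__main__':
    dataset = RotatedMNISTDataset(root_dir='data', split_scheme='official',
                                  model='gin')
    print(dataset.get_input(0))             # a torch_geometric Data object
    print(dataset._split_array.bincount())  # train / val / test sizes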