# datasets/upfd_dataset.py

1 import os
2 import torch
3 import numpy as np
4 from gds.datasets.gds_dataset import GDSDataset
5 from ogb.graphproppred import Evaluator
6 from torch_geometric.data import download_url, extract_zip
7 from torch_geometric.data.dataloader import Collater as PyGCollater
8 import torch_geometric
9 from .pyg_upfd_dataset import PyGUPFDDataset
10 from torch_geometric.utils import to_dense_adj
11 import torch.nn.functional as F
12  
  • E302 Expected 2 blank lines, found 1
class UPFDDataset(GDSDataset):
    """UPFD fake-news detection dataset, wrapped for the GDS benchmark.

    Wraps :class:`PyGUPFDDataset` (feature set ``'profile'``) and exposes the
    train/val/test split, metadata array, collate function, and evaluator
    expected by :class:`GDSDataset`.

    The task is binary graph classification; ROC-AUC is computed via the OGB
    ``'ogbg-molhiv'`` evaluator, which accepts binary logits directly.
    """

    _dataset_name = 'UPFD'
    _versions_dict = {
        '1.0': {
            'download_url': None,
            'compressed_size': None}}

    def __init__(self, version=None, root_dir='data', download=False,
                 split_scheme='official', random_split=False,
                 **dataset_kwargs):
        """Load the UPFD dataset and build split/metadata arrays.

        Args:
            version (str): Dataset version (stored, not otherwise used here).
            root_dir (str): Directory where the dataset is stored/downloaded.
            download (bool): Forwarded to the ``GDSDataset`` base class.
            split_scheme (str): ``'official'`` is mapped to ``'size'``.
            random_split (bool): If True, split 60/20/20 at random; otherwise
                split by the downloaded group metadata (groups 0-7 form
                train/val, groups 8-9 form test).
            **dataset_kwargs: Must contain ``'model'``; when it is
                ``'3wlgnn'`` a dense collate function is used.
        """
        self._version = version
        # internally call ogb package
        self.ogb_dataset = PyGUPFDDataset(root_dir, 'UPFD', 'profile')

        # set variables
        self._data_dir = self.ogb_dataset.root
        if split_scheme == 'official':
            split_scheme = 'size'
        self._split_scheme = split_scheme
        # Although the task is binary classification, the prediction target
        # contains nan values, thus we need float.
        self._y_type = 'float'
        self._y_size = self.ogb_dataset.num_tasks
        self._n_classes = self.ogb_dataset.__num_classes__

        self._split_array = torch.zeros(len(self.ogb_dataset)).long()

        self._y_array = self.ogb_dataset.data.y.unsqueeze(-1)
        self._metadata_fields = ['size', 'y']

        # Group metadata (one integer group id per graph); downloaded on
        # first use and cached next to the raw data.
        metadata_file_path = os.path.join(
            self.ogb_dataset.raw_dir, 'UPFD_group.npy')
        if not os.path.exists(metadata_file_path):
            metadata_zip_file_path = download_url(
                'https://www.dropbox.com/s/bs5u83x2vjf7t0f/UPFD_group.zip?dl=1',
                self.ogb_dataset.raw_dir)
            extract_zip(metadata_zip_file_path, self.ogb_dataset.raw_dir)
            os.unlink(metadata_zip_file_path)
        self._metadata_array_wo_y = torch.from_numpy(
            np.load(metadata_file_path)).reshape(-1, 1).long()
        self._metadata_array = torch.cat(
            (self._metadata_array_wo_y,
             torch.unsqueeze(self.ogb_dataset.data.y, dim=1)), 1)

        # Fixed seed so the split is reproducible across runs.
        np.random.seed(0)
        dataset_size = len(self.ogb_dataset)
        if random_split:
            random_index = np.random.permutation(dataset_size)
            train_split_idx = random_index[:int(6 / 10 * dataset_size)]
            val_split_idx = random_index[
                int(6 / 10 * dataset_size):int(8 / 10 * dataset_size)]
            test_split_idx = random_index[int(8 / 10 * dataset_size):]
        else:
            # Use the group info to split data: groups 0-7 -> train/val,
            # groups 8-9 -> test.
            train_val_group_idx = torch.tensor(range(0, 8))
            test_group_idx = torch.tensor(range(8, 10))

            def _group_mask(group_idx):
                # Boolean mask over all graphs whose group id is in
                # ``group_idx``. (Renamed from ``split_idx`` to avoid
                # shadowing the index variables below.)
                groups = torch.squeeze(self._metadata_array_wo_y)
                mask = torch.zeros(len(groups), dtype=torch.bool)
                for idx in group_idx:
                    mask += (groups == idx)
                return mask

            train_val_split_idx = _group_mask(train_val_group_idx)
            test_split_idx = _group_mask(test_group_idx)

            # Convert the train/val mask to indices and carve out 6/8 for
            # train, 2/8 for val.
            train_val_split_idx = torch.arange(
                dataset_size, dtype=torch.int64)[train_val_split_idx]
            train_val_sets_size = len(train_val_split_idx)
            random_index = np.random.permutation(train_val_sets_size)
            train_split_idx = train_val_split_idx[
                random_index[:int(6 / 8 * train_val_sets_size)]]
            val_split_idx = train_val_split_idx[
                random_index[int(6 / 8 * train_val_sets_size):]]

        self._split_array[train_split_idx] = 0
        self._split_array[val_split_idx] = 1
        self._split_array[test_split_idx] = 2

        if dataset_kwargs['model'] == '3wlgnn':
            self._collate = self.collate_dense
        else:
            # BUG FIX: the original compared version *strings*
            # (``__version__ >= '1.7.0'``), which misorders multi-digit
            # components — e.g. '1.10.0' < '1.7.0' lexicographically.
            # Compare numeric tuples instead; non-numeric parts
            # (e.g. 'rc1') are dropped.
            pyg_version = tuple(
                int(part)
                for part in torch_geometric.__version__.split('.')[:3]
                if part.isdigit())
            if pyg_version >= (1, 7):
                self._collate = PyGCollater(follow_batch=[], exclude_keys=[])
            else:
                self._collate = PyGCollater(follow_batch=[])
        self._metric = Evaluator('ogbg-molhiv')
        super().__init__(root_dir, download, split_scheme)

    def get_input(self, idx):
        """Return the PyG graph object at position ``idx``."""
        return self.ogb_dataset[int(idx)]

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Computes all evaluation metrics.
        Args:
            - y_pred (FloatTensor): Binary logits from a model
            - y_true (LongTensor): Ground-truth labels
            - metadata (Tensor): Metadata
            - prediction_fn (function): A function that turns y_pred into
              predicted labels. Only None is supported because OGB
              Evaluators accept binary logits
        Output:
            - results (dictionary): Dictionary of evaluation metrics
            - results_str (str): String summarizing the evaluation metrics
        """
        assert prediction_fn is None
        input_dict = {"y_true": y_true, "y_pred": y_pred}
        results = self._metric.eval(input_dict)

        return results, f"ROCAUC: {results['rocauc']:.3f}\n"

    # prepare dense tensors for GNNs using them; such as RingGNN, 3WLGNN
    def collate_dense(self, samples):
        """Collate ``(graph, y, metadata)`` triples into dense tensors.

        Each graph becomes a ``(1 + feat_dim, N, N)`` tensor: channel 0 is
        the symmetrically normalized dense adjacency, and the remaining
        channels carry one-hot node features on the diagonal.
        """
        def _sym_normalize_adj(adjacency):
            # D^{-1/2} A D^{-1/2}; isolated nodes (degree 0) get 0 instead
            # of inf.
            deg = torch.sum(adjacency, dim=0)
            deg_inv = torch.where(
                deg > 0, 1. / torch.sqrt(deg), torch.zeros(deg.size()))
            deg_inv = torch.diag(deg_inv)
            return torch.mm(deg_inv, torch.mm(adjacency, deg_inv))

        # The input samples is a list of triples (graph, label, metadata).
        node_feat_space = torch.tensor([8])

        graph_list, y_list, metadata_list = map(list, zip(*samples))
        y, metadata = torch.tensor(y_list), torch.stack(metadata_list)

        # insert size one at dim 0 because this dataset's y is 1d
        y = y.unsqueeze(0)

        feat = []
        for graph in graph_list:
            adj = _sym_normalize_adj(
                to_dense_adj(graph.edge_index,
                             max_num_nodes=graph.x.size(0)).squeeze())
            zero_adj = torch.zeros_like(adj)
            in_dim = node_feat_space.sum()

            # use node feats to prepare adj
            adj_feat = torch.stack([zero_adj for _ in range(in_dim)])
            adj_feat = torch.cat([adj.unsqueeze(0), adj_feat], dim=0)

            def convert(feature, space):
                # One-hot encode each categorical node-feature column.
                out = []
                for i, label in enumerate(feature):
                    out.append(F.one_hot(label, space[i]))
                return torch.cat(out)

            for node, node_feat in enumerate(graph.x):
                adj_feat[1:1 + node_feat_space.sum(), node, node] = \
                    convert(node_feat.unsqueeze(0), node_feat_space)

            feat.append(adj_feat)

        feat = torch.stack(feat)

        return feat, y, metadata