import os
import os.path as osp
from typing import Callable, List, Optional

import torch
from torch_geometric.data import (Data, InMemoryDataset, download_url,
                                  extract_zip)


class PyGColoredMNISTDataset(InMemoryDataset):
    """In-memory PyG dataset for the ColoredMNIST graphs.

    The raw ``{name}.pt`` file is downloaded from Dropbox and contains a
    list of dictionaries that :meth:`process` converts into
    ``torch_geometric.data.Data`` objects.
    """

    names = ['ColoredMNIST', 'ColoredMNIST_expired']
    urls = {
        'ColoredMNIST':
            'https://www.dropbox.com/s/3zfi43erynkmi0i/ColoredMNIST.zip?dl=1',
        'ColoredMNIST_expired':
            'https://www.dropbox.com/s/bz0mkww2ivu8rwn/'
            'ColoredMNIST_expired.zip?dl=1',
    }

    def __init__(self, root: str, name: str,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        self.name = name
        assert self.name in self.names
        if name == 'ColoredMNIST':
            self.num_tasks = 1
            self.__num_classes__ = 1
        else:
            raise NotImplementedError

        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self) -> List[str]:
        # raw file name must match the dataset name
        return [f'{self.name}.pt']

    @property
    def processed_file_names(self) -> List[str]:
        return [f'{self.name}_processed_mean_px_feat.pt']

    def download(self):
        path = download_url(self.urls[self.name], self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    def process(self):
        inputs = torch.load(self.raw_paths[0])

        # data_list = [Data(**data_dict) for data_dict in inputs]

        # Keep only the first two node-feature columns before wrapping each
        # raw dictionary into a `Data` object.
        data_list = []
        for data_dict in inputs:
            data_dict['x'] = data_dict['x'][:, :2]
            data_list.append(Data(**data_dict))

        # total_num_nodes = 0
        # for data_dict in inputs:
        #     total_num_nodes += data_dict['x'].shape[0]
        #
        # np.random.seed(0)
        # padded_cate = np.random.randint(8, size=total_num_nodes)
        #
        # data_list = []
        # ptr = 0
        # for i, data_dict in enumerate(inputs):
        #     num_node = data_dict['x'].shape[0]
        #     cate = torch.tensor(padded_cate[ptr:ptr+num_node]).view((-1, 1))
        #     x = torch.cat((data_dict['x'][:, :2], cate), dim=1)
        #     data_dict['x'] = x
        #     data_list.append(Data(**data_dict))
        #     ptr += num_node
        # assert ptr == total_num_nodes

        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        torch.save(self.collate(data_list), self.processed_paths[0])
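

# Minimal usage sketch (illustrative only): the root directory below is a
# placeholder, not a path defined elsewhere in this module. Download and
# processing run automatically the first time the dataset is instantiated.
if __name__ == '__main__':
    dataset = PyGColoredMNISTDataset(root='data/colored_mnist',
                                     name='ColoredMNIST')
    print(f'{len(dataset)} graphs, example: {dataset[0]}')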