1 import os
2 import os.path as osp
3 import torch
-
E501
Line too long (81 > 79 characters)
4 from torch_geometric.data import InMemoryDataset, download_url, extract_zip, Data
5 from typing import Optional, Callable, List
6
7
class PyGRotatedMNISTDataset(InMemoryDataset):
    """Rotated-MNIST graphs wrapped as a PyG ``InMemoryDataset``.

    The raw archive for the selected ``name`` is downloaded from ``urls``,
    unpacked into ``<root>/<name>/raw`` and converted into a collated file
    under ``<root>/<name>/processed``.

    Args:
        root: Directory under which the per-dataset folders are created.
        name: Dataset variant; must be one of ``names``. Only
            ``'RotatedMNIST'`` is currently supported.
        transform: Optional callable forwarded to ``InMemoryDataset``.
        pre_transform: Optional callable applied to every ``Data`` object
            in ``process`` before the result is saved.
        pre_filter: Optional predicate; ``process`` keeps only the ``Data``
            objects for which it returns a truthy value.

    Raises:
        ValueError: If ``name`` is not listed in ``names``.
        NotImplementedError: If ``name`` is listed but not supported.
    """

    # Known dataset variants and the archive each one is fetched from.
    names = ['RotatedMNIST', 'RotatedMNIST_expired']
    urls = {
        'RotatedMNIST': (
            'https://www.dropbox.com/s/3tgk65bpjp8mvhj/'
            'RotatedMNIST.zip?dl=1'
        ),
        'RotatedMNIST_expired': (
            'https://www.dropbox.com/s/5kybifusm8jexna/'
            'RotatedMNIST_expired.zip?dl=1'
        ),
    }

    def __init__(self, root: str, name: str,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        # ``raw_dir`` / ``processed_dir`` read ``self.name``, so it has to be
        # set before the base-class initializer runs.
        self.name = name

        # Explicit check instead of ``assert``: assertions are stripped when
        # Python runs with ``-O``, which would let a bad name slip through.
        if name not in self.names:
            raise ValueError(
                f'Unknown dataset name {name!r}; expected one of '
                f'{self.names}'
            )
        if name != 'RotatedMNIST':
            raise NotImplementedError(
                f'Dataset {name!r} is not supported yet'
            )

        self.num_tasks = 1
        self.__num_classes__ = 10

        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        """Return the directory holding the raw files of this variant."""
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        """Return the directory holding the processed file of this variant."""
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self) -> List[str]:
        """Return the expected raw file name(s)."""
        # raw file name must match the dataset name
        return [f'{self.name}.pt']

    @property
    def processed_file_names(self) -> List[str]:
        """Return the name(s) of the processed file(s)."""
        return [f'{self.name}_processed_mean_px_feat.pt']

    def download(self):
        """Fetch the zip archive into ``raw_dir``, unpack it, remove it."""
        path = download_url(self.urls[self.name], self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    def process(self):
        """Convert the raw file into a collated, saved list of ``Data``.

        Every raw entry is a dict of ``Data`` keyword arguments whose ``'x'``
        entry is cut down to its first column. ``pre_filter`` and
        ``pre_transform`` are then applied, in that order, when provided.
        """
        # NOTE(review): ``torch.load`` unpickles its input and the raw file
        # comes from a remote download -- only use with a trusted source.
        inputs = torch.load(self.raw_paths[0])

        # Build each ``Data`` from a merged copy so the loaded dicts are not
        # modified in place; only the first column of ``x`` is kept.
        data_list = [
            Data(**{**data_dict, 'x': data_dict['x'][:, :1]})
            for data_dict in inputs
        ]

        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        torch.save(self.collate(data_list), self.processed_paths[0])