1 import os
2 import os.path as osp
3 import torch
-
E501
Line too long (81 > 79 characters)
4 from torch_geometric.data import InMemoryDataset, download_url, extract_zip, Data
5 from typing import Optional, Callable, List
6
7
class PyGSBMEnvironmentDataset(InMemoryDataset):
    """In-memory dataset wrapper for the ``SBM-Environment`` graphs.

    The raw archive is fetched from the address registered in ``urls``,
    unpacked under ``<root>/<name>/raw`` and converted into one collated
    file stored under ``<root>/<name>/processed``.

    Args:
        root: Base directory below which the per-dataset ``raw`` and
            ``processed`` folders are placed.
        name: Dataset identifier; has to be one of the entries of ``names``.
        transform: Forwarded unchanged to ``InMemoryDataset``.
        pre_transform: Applied to every graph inside ``process`` before the
            collated file is written.
        pre_filter: Predicate evaluated inside ``process``; graphs for which
            it returns a falsy value are dropped.
    """

    # Identifiers accepted for the ``name`` constructor argument.
    names = ['SBM-Environment']

    # Download location of the zipped raw data, keyed by dataset name.
    urls = {
        'SBM-Environment':
            'https://www.dropbox.com/s/rngulpfawv2hmlr/SBM-Environment.zip?dl=1'
    }

    def __init__(self, root: str, name: str,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        # ``name`` has to be stored before the parent constructor runs,
        # because ``raw_dir`` / ``processed_dir`` are derived from it.
        self.name = name
        assert self.name in self.names

        # Guard clause: only one dataset is implemented.
        if name != 'SBM-Environment':
            raise NotImplementedError
        self.num_tasks = 1
        self.__num_classes__ = 3

        super().__init__(root, transform, pre_transform, pre_filter)

        # Read back the collated (data, slices) pair written by ``process``.
        loaded = torch.load(self.processed_paths[0])
        self.data, self.slices = loaded

    @property
    def raw_dir(self) -> str:
        """Folder holding the downloaded files: ``<root>/<name>/raw``."""
        dataset_root = osp.join(self.root, self.name)
        return osp.join(dataset_root, 'raw')

    @property
    def processed_dir(self) -> str:
        """Folder holding the collated file: ``<root>/<name>/processed``."""
        dataset_root = osp.join(self.root, self.name)
        return osp.join(dataset_root, 'processed')

    @property
    def raw_file_names(self) -> List[str]:
        """Single raw file, named after the dataset with a ``.pt`` suffix."""
        # The raw file name is tied to the dataset name.
        raw_name = f'{self.name}.pt'
        return [raw_name]

    @property
    def processed_file_names(self) -> List[str]:
        """Single processed file: ``<name>_processed.pt``."""
        processed_name = f'{self.name}_processed.pt'
        return [processed_name]

    def download(self):
        """Fetch the zip archive into ``raw_dir``, unpack it, then delete it."""
        archive_url = self.urls[self.name]
        archive_path = download_url(archive_url, self.raw_dir)
        extract_zip(archive_path, self.raw_dir)
        # Only the extracted contents are kept; the archive itself is removed.
        os.unlink(archive_path)

    def process(self):
        """Turn the raw file into a collated file at ``processed_paths[0]``.

        Each entry of the loaded raw object is expanded as keyword arguments
        into a ``Data`` instance. ``pre_filter`` is applied to the full list
        first, and ``pre_transform`` afterwards to the surviving graphs.
        """
        records = torch.load(self.raw_paths[0])

        graphs = []
        for record in records:
            graphs.append(Data(**record))

        # Filtering finishes for all graphs before any transform is applied.
        if self.pre_filter is not None:
            graphs = list(filter(self.pre_filter, graphs))

        if self.pre_transform is not None:
            graphs = list(map(self.pre_transform, graphs))

        collated = self.collate(graphs)
        torch.save(collated, self.processed_paths[0])