⬅ datasets/pyg_upfd_dataset.py source

1 import os
2 import os.path as osp
3 import torch
4 import numpy as np
5 import scipy.sparse as sp
  • E501 Line too long (81 > 79 characters)
6 from torch_geometric.data import InMemoryDataset, download_url, extract_zip, Data
  • F401 'torch.nn.functional as F' imported but unused
7 import torch.nn.functional as F
8 from torch_sparse import coalesce
9 from torch_geometric.io import read_txt_array
10  
11 import pdb
12  
  • E302 Expected 2 blank lines, found 1
class PyGUPFDDataset(InMemoryDataset):
    """In-memory graph dataset built from the raw ``UPFD`` archive.

    The raw archive is downloaded from ``urls[name]`` and split into one
    ``Data`` object per graph, ordered as train, then val, then test indices.

    NOTE: the feature file ``new_{feature}_feature.npz`` is only used to
    determine the number of nodes. The node attribute ``x`` stored in the
    processed dataset is a 1-D ``int64`` tensor of pseudo-random categories
    in ``[0, 8)`` drawn with a fixed seed of 0, not the loaded features.

    Args:
        root: Root directory; raw files live in ``<root>/<name>/raw`` and
            processed files in ``<root>/<name>/processed/<feature>``.
        name: Dataset name; must be one of ``names``.
        feature: Feature variant; selects ``new_{feature}_feature.npz`` and
            the processed sub-directory.
        transform: Forwarded unchanged to ``InMemoryDataset``.
        pre_transform: Applied to every graph in ``process`` before saving.
        pre_filter: Predicate applied in ``process``; graphs for which it
            returns a falsy value are dropped before saving.

    Raises:
        AssertionError: If ``name`` is not in ``names``.
        NotImplementedError: If ``name`` has no task configuration.
    """

    # Dataset names this class knows how to build.
    names = ['UPFD']
    # Download location of the raw archive for each supported name.
    urls = {
        'UPFD': 'https://www.dropbox.com/s/fgr5uzishps3mv3/UPFD.zip?dl=1'
    }

    def __init__(self, root, name, feature, transform=None,
                 pre_transform=None, pre_filter=None):
        # These attributes are read by raw_dir / processed_dir /
        # raw_file_names, so they are assigned before the base initializer
        # runs (presumably it triggers download()/process() -- confirm
        # against the installed torch_geometric version).
        self.root = root
        self.name = name
        assert self.name in self.names
        self.feature = feature
        if name == 'UPFD':
            self.num_tasks = 1
            self.__num_classes__ = 1
        else:
            raise NotImplementedError
        super().__init__(root, transform, pre_transform, pre_filter)

        path = self.processed_paths[0]
        self.data, self.slices = torch.load(path)

    @property
    def raw_dir(self):
        """Directory holding the extracted raw files."""
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        """Directory holding the processed file for the chosen feature."""
        return osp.join(self.root, self.name, 'processed', self.feature)

    @property
    def raw_file_names(self):
        """Raw files that ``process`` reads from ``raw_dir``."""
        return [
            'node_graph_id.npy', 'graph_labels.npy', 'A.txt', 'train_idx.npy',
            'val_idx.npy', 'test_idx.npy', f'new_{self.feature}_feature.npz'
        ]

    @property
    def processed_file_names(self):
        """Name of the single processed file written by ``process``."""
        return ['all.pt']

    def download(self):
        """Download the archive into ``raw_dir``, extract it, delete the zip."""
        path = download_url(self.urls[self.name], self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    def process(self):
        """Build per-graph ``Data`` objects from the raw files and save them.

        Writes the collated result of all train/val/test graphs (in that
        order) to ``processed_paths[0]``.
        """
        # Only the row count (number of nodes) of the feature matrix is
        # needed, so the sparse matrix is never densified.
        features = sp.load_npz(
            osp.join(self.raw_dir, f'new_{self.feature}_feature.npz'))
        num_nodes = features.shape[0]

        edge_index = read_txt_array(osp.join(self.raw_dir, 'A.txt'), sep=',',
                                    dtype=torch.long).t()
        edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes)

        y = np.load(osp.join(self.raw_dir, 'graph_labels.npy'))
        y = torch.from_numpy(y).to(torch.long)
        # Re-map the distinct label values to consecutive ids 0..k-1.
        _, y = y.unique(sorted=True, return_inverse=True)
        num_graphs = y.size(0)

        # batch[i] is the id of the graph that node i belongs to.
        batch = np.load(osp.join(self.raw_dir, 'node_graph_id.npy'))
        batch = torch.from_numpy(batch).to(torch.long)

        # Cumulative boundaries used to slice the concatenated tensors back
        # into individual graphs.
        node_slice = torch.cumsum(batch.bincount(), 0)
        node_slice = torch.cat([torch.tensor([0]), node_slice])
        # minlength keeps one entry per graph even when the last graphs have
        # no edges; without it the slice vector would be too short for them.
        edge_counts = batch[edge_index[0]].bincount(minlength=num_graphs)
        edge_slice = torch.cumsum(edge_counts, 0)
        edge_slice = torch.cat([torch.tensor([0]), edge_slice])
        graph_slice = torch.arange(num_graphs + 1)
        self.slices = {
            'x': node_slice,
            'edge_index': edge_slice,
            'y': graph_slice
        }

        # Shift global node ids so each graph's nodes start at 0.
        edge_index -= node_slice[batch[edge_index[0]]].view(1, -1)

        # Node attributes are random categories in [0, 8). A private
        # RandomState(0) yields the same draws as seeding the global NumPy
        # RNG with 0, without resetting that global state for the caller.
        rng = np.random.RandomState(0)
        x = rng.randint(8, size=(num_nodes,))
        x = torch.from_numpy(x).long()

        assert x.dim() == 1
        assert x.dtype == torch.int64

        self.data = Data(x=x, edge_index=edge_index, y=y)

        # Concatenate the graph indices of the three splits in fixed order.
        idx = []
        for split in ['train', 'val', 'test']:
            split_path = osp.join(self.raw_dir, f'{split}_idx.npy')
            idx += np.load(split_path).tolist()
        data_list = [self.get(i) for i in idx]
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]
        torch.save(self.collate(data_list), self.processed_paths[0])

    def __repr__(self):
        return (f'{self.__class__.__name__}({len(self)}, name={self.name}, '
                f'feature={self.feature})')
110  
111  
  • E203 Whitespace before ':'
if __name__ == '__main__':
    # Manual smoke test: build (or load) the 'profile' variant of UPFD and
    # drop into the debugger to inspect it.
    import sys

    # The dataset root may be given as the first command-line argument;
    # otherwise the original hard-coded location is used.
    default_root = '/cmlscratch/kong/datasets/graph_domain'
    root = sys.argv[1] if len(sys.argv) > 1 else default_root

    # Keep a reference so the dataset is reachable from the debugger prompt.
    dataset = PyGUPFDDataset(root, 'UPFD', 'profile')
    print(dataset)

    # pdb is already imported at module level; no second import is needed.
    pdb.set_trace()
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
  • E303 Too many blank lines (27)
145 ########################################################################
146 # class Net(torch.nn.Module):
147 # def __init__(self, model, in_channels, hidden_channels, out_channels,
148 # concat=False):
149 # super(Net, self).__init__()
150 # self.concat = concat
151 #
152 # if model == 'GCN':
153 # self.conv1 = GCNConv(in_channels, hidden_channels)
154 # elif model == 'SAGE':
155 # self.conv1 = SAGEConv(in_channels, hidden_channels)
156 # elif model == 'GAT':
157 # self.conv1 = GATConv(in_channels, hidden_channels)
158 #
159 # if self.concat:
160 # self.lin0 = Linear(in_channels, hidden_channels)
161 # self.lin1 = Linear(2 * hidden_channels, hidden_channels)
162 #
163 # self.lin2 = Linear(hidden_channels, out_channels)
164 #
165 # def forward(self, x, edge_index, batch):
166 # h = self.conv1(x, edge_index).relu()
167 # h = global_max_pool(h, batch)
168 #
169 # if self.concat:
170 # # Get the root node (tweet) features of each graph:
171 # root = (batch[1:] - batch[:-1]).nonzero(as_tuple=False).view(-1)
172 # root = torch.cat([root.new_zeros(1), root + 1], dim=0)
173 # news = x[root]
174 #
175 # news = self.lin0(news).relu()
176 # h = self.lin1(torch.cat([news, h], dim=-1)).relu()
177 #
178 # h = self.lin2(h)
179 # return h.log_softmax(dim=-1)