⬅ datasets/gsn/utils_data_prep.py source

1 import os
2 import csv
3 import pickle
4 from collections import namedtuple
5 import networkx as nx
6 import numpy as np
  • F401 'random' imported but unused
7 import random
8 import torch
9 from sklearn.model_selection import StratifiedKFold
10 from ogb.graphproppred import PygGraphPropPredDataset
11 from torch_geometric.utils import to_undirected
12  
  • E302 Expected 2 blank lines, found 1
class S2VGraph:
    def __init__(self, g, label, node_tags=None, node_features=None):
        '''
        Container for one graph of a TU-style classification dataset.

        Code obtained from here: https://github.com/weihua916/powerful-gnns

        g: a networkx graph
        label: an integer graph label
        node_tags: a list of integer node tags
        node_features: optional node features, stored as given; when
            omitted the attribute starts as 0 and is expected to be filled
            in later (e.g. with a one-hot encoding of the tags)
        '''
        self.label = label
        self.g = g
        self.node_tags = node_tags
        # Per-node neighbor lists (index = node id); filled in by the loader.
        self.neighbors = []
        # Previously the argument was silently discarded; honor it when given
        # while keeping the historical placeholder value 0 otherwise.
        self.node_features = 0 if node_features is None else node_features
        # Edge list as a tensor; placeholder 0 until the loader sets it.
        self.edge_mat = 0

        # Largest neighbor-list length in the graph; 0 until computed.
        self.max_neighbor = 0
33  
34  
def load_data(path, name, degree_as_tag):
    '''
    Load a TU-style graph classification dataset from <path>/<name>.txt.

    Code obtained from here: https://github.com/weihua916/powerful-gnns

    path: directory containing the <name>.txt file
    name: dataset name (the file stem)
    degree_as_tag: if True, every node's tag is replaced by its degree
        before the one-hot encoding is built

    Returns (g_list, num_classes): the list of S2VGraph objects with
    neighbors, max_neighbor, edge_mat and one-hot node_features filled in
    and label mapped to a consecutive id starting at 0, plus the number of
    distinct graph labels.
    '''
    print('loading data')
    g_list, label_dict = _read_tu_graphs(path, name)

    for g in g_list:
        _attach_structure(g, label_dict)

    num_tags = _attach_one_hot_tags(g_list, degree_as_tag)

    print('# classes: %d' % len(label_dict))
    print('# maximum node tag: %d' % num_tags)

    print("# data: %d" % len(g_list))

    return g_list, len(label_dict)


def _read_tu_graphs(path, name):
    '''Parse the text file into S2VGraphs with raw labels.

    Layout, as read below: the first line holds the number of graphs; each
    graph starts with a line "<num_nodes> <label>" followed by one line per
    node "<tag> <num_neighbors> <neighbor ids...>". Any extra columns after
    the neighbor ids are ignored.

    Returns (g_list, label_dict) where label_dict maps each raw label to a
    consecutive id in order of first appearance.
    '''
    g_list = []
    label_dict = {}
    feat_dict = {}

    with open(os.path.join(path, name + '.txt'), 'r') as f:
        n_g = int(f.readline().strip())
        for _ in range(n_g):
            n, label = [int(w) for w in f.readline().strip().split()]
            label_dict.setdefault(label, len(label_dict))

            g = nx.Graph()
            node_tags = []
            for j in range(n):
                g.add_node(j)
                row = f.readline().strip().split()
                # Tag, neighbor count and the neighbor ids themselves.
                n_cols = int(row[1]) + 2
                row = [int(w) for w in row[:n_cols]]
                # Raw tags are re-indexed consecutively across the dataset.
                node_tags.append(
                    feat_dict.setdefault(row[0], len(feat_dict)))
                for neighbor in row[2:]:
                    g.add_edge(j, neighbor)

            # Neighbor ids >= n would have silently created extra nodes.
            assert len(g) == n, 'graph node count does not match header'

            g_list.append(S2VGraph(g, label, node_tags))

    return g_list, label_dict


def _attach_structure(graph, label_dict):
    '''Fill neighbors, max_neighbor and edge_mat; map label to its id.'''
    num_nodes = len(graph.g)
    graph.neighbors = [[] for _ in range(num_nodes)]
    for i, j in graph.g.edges():
        graph.neighbors[i].append(j)
        graph.neighbors[j].append(i)
    # default=0 keeps a node-less graph from crashing max().
    graph.max_neighbor = max(
        (len(nbrs) for nbrs in graph.neighbors), default=0)

    graph.label = label_dict[graph.label]

    # Store each undirected edge in both directions.
    edges = [list(pair) for pair in graph.g.edges()]
    edges.extend([[i, j] for j, i in edges])
    # reshape(-1, 2) turns an edgeless graph into a (0, 2) tensor, so the
    # transpose always yields shape (2, num_directed_edges) instead of
    # failing on a 1-D empty tensor.
    graph.edge_mat = torch.tensor(
        edges, dtype=torch.long).reshape(-1, 2).t()


def _attach_one_hot_tags(g_list, degree_as_tag):
    '''One-hot encode node tags into node_features; return the tag count.'''
    if degree_as_tag:
        for g in g_list:
            g.node_tags = list(dict(g.g.degree).values())

    # Extracting unique tag labels (construction kept as-is so the column
    # order of the one-hot encoding is unchanged).
    tagset = set()
    for g in g_list:
        tagset = tagset.union(set(g.node_tags))
    tagset = list(tagset)
    tag2index = {tag: i for i, tag in enumerate(tagset)}

    for g in g_list:
        columns = [tag2index[tag] for tag in g.node_tags]
        g.node_features = torch.zeros(len(g.node_tags), len(tagset))
        g.node_features[range(len(g.node_tags)), columns] = 1

    return len(tagset)
137  
138  
  • E501 Line too long (81 > 79 characters)
def load_zinc_data(path, name, degree_as_tag, num_atom_type=28,
                   num_bond_type=4):
    '''
    Load the ZINC molecular regression subset as Graph namedtuples.

    Splits and preprocessing follow
    https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/main_molecules_graph_regression.py

    For each split in train, val, test (concatenated in that order) this
    reads <path>/molecules/<split>.pickle and keeps the molecules whose
    positions are listed on the first row of <path>/indices/<split>.index.

    name: must be 'ZINC' (case-insensitive)
    degree_as_tag: unused; kept so all loaders share a signature
    num_atom_type, num_bond_type: passed through unchanged in the result

    Returns (data, 1, num_atom_type, num_bond_type). Each entry of data is
    Graph(node_features, edge_mat, edge_features, label).
    '''
    assert name.upper() == 'ZINC'
    Graph = namedtuple(
        'Graph', ['node_features', 'edge_mat', 'edge_features', 'label'])

    def _to_graph(molecule):
        adj = molecule['bond_type']
        # One row of coordinates per non-zero entry of the bond matrix.
        coords = (adj != 0).nonzero()
        # unbind gives one index tensor per dimension, so this gathers the
        # bond value at every non-zero coordinate.
        bond_values = adj[coords.unbind(dim=1)].long()
        return Graph(
            molecule['atom_type'].long(),
            coords.permute(1, 0),
            bond_values,
            molecule['logP_SA_cycle_normalized'],
        )

    data = []
    for split_name in ('train', 'val', 'test'):
        pickle_path = os.path.join(
            path, 'molecules', '{}.pickle'.format(split_name))
        index_path = os.path.join(
            path, 'indices', '{}.index'.format(split_name))

        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this at trusted dataset files.
        with open(pickle_path, "rb") as f:
            molecules = pickle.load(f)

        # Sampled indices; only the first CSV row is used.
        with open(index_path, "r") as f:
            index_rows = [list(map(int, row)) for row in csv.reader(f)]

        data.extend(_to_graph(molecules[i]) for i in index_rows[0])

    return data, 1, num_atom_type, num_bond_type
175  
176  
def load_ogb_data(path, name, degree_as_tag):
    '''
    Load an OGB graph-property-prediction dataset as Graph namedtuples.

    Splits and preprocessing follow https://github.com/snap-stanford/ogb

    path: root directory handed to PygGraphPropPredDataset
    name: OGB dataset name, e.g. 'ogbg-ppa'
    degree_as_tag: unused; kept so all loaders share a signature

    Returns (graph_list, num_classes). Each Graph(node_features, edge_mat,
    edge_features, label) is built from the datum's x, edge_index,
    edge_attr and y. num_classes is dataset.num_classes for 'ogbg-ppa' and
    dataset.num_tasks for every other dataset.
    '''
    is_ppa = name == 'ogbg-ppa'

    if is_ppa:
        def add_zeros(data):
            # Replace node features with a single 0 per node.
            data.x = torch.zeros(data.num_nodes, dtype=torch.long)
            return data

        print('Applying transform {} to dataset {}.'.format(add_zeros, name))
        dataset = PygGraphPropPredDataset(
            name=name, root=path, transform=add_zeros)
    else:
        dataset = PygGraphPropPredDataset(name=name, root=path)

    Graph = namedtuple(
        'Graph', ['node_features', 'edge_mat', 'edge_features', 'label'])
    graph_list = [
        Graph(datum.x, datum.edge_index, datum.edge_attr, datum.y)
        for datum in dataset
    ]

    if is_ppa:
        num_classes = dataset.num_classes
    else:
        num_classes = dataset.num_tasks
    return graph_list, num_classes
198  
199  
def load_g6_graphs(path, name):
    '''
    Load graphs stored in graph6 format from <path>/<name>.g6.

    Used to load the SR graphs obtained from
    http://users.cecs.anu.edu.au/~bdm/data/graphs.html. The data are not
    split, because no training is performed (the network is used with
    random weights for the SR experiment).

    Each graph becomes Graph(node_features, edge_mat, label):
    node_features is an all-ones (num_nodes, 1) tensor, edge_mat is the
    undirected edge index produced by to_undirected, and label is the
    graph's position in the file as a 0-d long tensor.

    Returns (graph_list, num_classes), where num_classes is the number of
    graphs (every graph gets a distinct label).
    '''
    dataset = nx.read_graph6(os.path.join(path, name + '.g6'))
    # read_graph6 returns a bare Graph (not a list) when the file holds a
    # single graph; without this, the loop below would iterate over that
    # graph's nodes and num_classes would be its node count.
    if isinstance(dataset, nx.Graph):
        dataset = [dataset]

    Graph = namedtuple('Graph', ['node_features', 'edge_mat', 'label'])
    graph_list = []
    for i, datum in enumerate(dataset):
        x = torch.ones(datum.number_of_nodes(), 1)
        # reshape(-1, 2) keeps an edgeless graph as a (0, 2) tensor so the
        # transpose works instead of failing on a 1-D empty tensor.
        edges = torch.tensor(
            list(datum.edges()), dtype=torch.long).reshape(-1, 2)
        edge_index = to_undirected(edges.transpose(1, 0))
        graph_list.append(Graph(x, edge_index, torch.tensor(i).long()))
    num_classes = len(graph_list)

    return graph_list, num_classes
216  
217  
def separate_data(graph_list, seed, fold_idx):
    '''
    Return the (train, test) graphs of one stratified 10-fold split.

    Code obtained from here: https://github.com/weihua916/powerful-gnns

    graph_list: graphs exposing either a `label` or a `y` attribute, used
        as the stratification target
    seed: random_state for the shuffled StratifiedKFold
    fold_idx: which of the 10 folds to hold out as the test set (0..9)

    Raises NotImplementedError if the first graph has neither attribute.
    '''
    assert 0 <= fold_idx < 10, "fold_idx must be from 0 to 9."
    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)

    # The first graph decides which attribute holds the label.
    first = graph_list[0]
    if hasattr(first, 'label'):
        label_attr = 'label'
    elif hasattr(first, 'y'):
        label_attr = 'y'
    else:
        raise NotImplementedError
    labels = [getattr(graph, label_attr) for graph in graph_list]

    folds = list(skf.split(np.zeros(len(labels)), labels))
    train_idx, test_idx = folds[fold_idx]

    train_graph_list = [graph_list[i] for i in train_idx]
    test_graph_list = [graph_list[i] for i in test_idx]
    return train_graph_list, test_graph_list
241  
  • E302 Expected 2 blank lines, found 1
def separate_data_given_split(graph_list, path, fold_idx):
    '''
    Split graph_list using pre-computed index files.

    Reads <path>/10fold_idx/{train,test,val}_idx-<fold_idx + 1>.txt; each
    file holds integer positions into graph_list as parsed by np.loadtxt.
    The validation file is optional.

    fold_idx: -1..9, where -1 selects the special model-selection split
        (files with suffix 0)

    Returns (train_graph_list, test_graph_list, val_graph_list);
    val_graph_list is None when no validation file exists.
    '''
    assert -1 <= fold_idx < 10, (
        "Parameter fold_idx must be from -1 to 9, with -1 referring to "
        "the special model selection split.")

    split_dir = os.path.join(path, '10fold_idx')
    file_id = fold_idx + 1

    def _filename(prefix):
        return os.path.join(
            split_dir, '{}_idx-{}.txt'.format(prefix, file_id))

    def _select(filename):
        # ndmin=1: a file containing a single index would otherwise load as
        # a 0-d array, which cannot be iterated.
        indices = np.loadtxt(filename, dtype=int, ndmin=1)
        return [graph_list[i] for i in indices]

    train_graph_list = _select(_filename('train'))
    test_graph_list = _select(_filename('test'))

    val_filename = _filename('val')
    val_graph_list = None
    if os.path.exists(val_filename):
        val_graph_list = _select(val_filename)

    return train_graph_list, test_graph_list, val_graph_list