⬅ datasets/gsn/utils_data_gen.py source

1 import os
2 import torch
3 import torch_geometric as torch_geo
  • F401 'torch_geometric.utils.sort_edge_index' imported but unused
  • F401 'torch_geometric.utils.remove_self_loops' imported but unused
4 from torch_geometric.utils import sort_edge_index, remove_self_loops
5 from torch_geometric.data import Data
  • F401 '.utils_data_prep.load_data' imported but unused
  • F401 '.utils_data_prep.load_g6_graphs' imported but unused
  • E501 Line too long (85 > 79 characters)
6 from .utils_data_prep import load_data, load_zinc_data, load_ogb_data, load_g6_graphs
  • F401 '.utils_graph_processing.automorphism_orbits' imported but unused
7 from .utils_graph_processing import automorphism_orbits
8  
  • F401 'multiprocessing as mp' imported but unused
9 import multiprocessing as mp
10 import time
11  
12 from .utils_misc import isnotebook
13 if isnotebook():
14 from tqdm import tqdm_notebook as tqdm
15 else:
16 from tqdm import tqdm
17  
18  
def generate_dataset(data_path,
                     dataset_name,
                     k,
                     extract_ids_fn,
                     count_fn,
                     automorphism_fn,
                     id_type,
                     multiprocessing=False,
                     num_processes=1,
                     **subgraph_params):
    """Load a dataset and run the substructure pipeline on every graph.

    For each edge list in ``subgraph_params['edge_list']`` the substructure
    is analysed with ``automorphism_fn``; the dataset is then loaded and
    every graph is converted with ``_prepare``, either in parallel through
    joblib or sequentially with a progress bar.

    Args:
        data_path: Dataset directory; its parent directory is what gets
            passed to the loaders.
        dataset_name: ``'ZINC'`` selects ``load_zinc_data``; any other value
            is handed to ``load_ogb_data``.
        k: Not used by this function (kept for interface compatibility).
        extract_ids_fn: Callable forwarded to ``_prepare`` as ``ex_fn``.
        count_fn: Callable forwarded to ``_prepare`` as ``cnt_fn``.
        automorphism_fn: Callable invoked with the keywords ``edge_list``,
            ``directed`` and ``directed_orbits``; must return the tuple
            ``(subgraph, orbit_partition, orbit_membership, aut_count)``.
        id_type: Not used by this function (kept for interface
            compatibility).
        multiprocessing: When true, graphs are processed with
            ``joblib.Parallel``.
        num_processes: ``n_jobs`` given to ``joblib.Parallel``.
        **subgraph_params: Must contain ``'edge_list'``, ``'directed'`` and
            ``'directed_orbits'``; the whole dict is also forwarded to
            ``_prepare``.

    Returns:
        Tuple ``(graphs_ptg, num_classes, num_node_type, num_edge_type,
        orbit_partition_sizes)``. ``num_node_type`` and ``num_edge_type``
        are ``None`` for every dataset other than ``'ZINC'``;
        ``orbit_partition_sizes`` holds ``len(orbit_partition)`` for each
        substructure.

    Raises:
        ValueError: If ``'edge_list'`` is missing from ``subgraph_params``.
        KeyError: If ``'directed'`` or ``'directed_orbits'`` is missing
            while ``'edge_list'`` is non-empty.
    """
    if 'edge_list' not in subgraph_params:
        raise ValueError('Edge list not provided.')

    # Compute the orbits of each substructure in the list, as well as the
    # vertex automorphism count.
    subgraph_dicts = []
    orbit_partition_sizes = []
    for edge_list in subgraph_params['edge_list']:
        subgraph, orbit_partition, orbit_membership, aut_count = \
            automorphism_fn(
                edge_list=edge_list,
                directed=subgraph_params['directed'],
                directed_orbits=subgraph_params['directed_orbits'])
        subgraph_dicts.append({'subgraph': subgraph,
                               'orbit_partition': orbit_partition,
                               'orbit_membership': orbit_membership,
                               'aut_count': aut_count})
        orbit_partition_sizes.append(len(orbit_partition))

    # Load and preprocess the dataset. The loaders receive the parent of
    # the directory that was passed in.
    data_path = os.path.join(data_path, os.pardir)
    if dataset_name == 'ZINC':
        graphs, num_classes, num_node_type, num_edge_type = load_zinc_data(
            data_path, dataset_name, False)
    else:
        graphs, num_classes = load_ogb_data(data_path, dataset_name, False)
        num_node_type, num_edge_type = None, None

    if multiprocessing:
        # Parallel computation of subgraph isomorphisms and creation of the
        # data structure. joblib is imported lazily so that it is only
        # required when this branch is actually taken.
        print("Preparing dataset in parallel...")
        start = time.time()
        from joblib import delayed, Parallel
        graphs_ptg = Parallel(n_jobs=num_processes, verbose=10)(
            delayed(_prepare)(graph,
                              subgraph_dicts,
                              subgraph_params,
                              dataset_name,
                              extract_ids_fn,
                              count_fn)
            for graph in graphs)
        print('Done ({:.2f} secs).'.format(time.time() - start))
    else:
        # Single-threaded computation of subgraph isomorphisms and creation
        # of the data structure. Iterating the collection directly (rather
        # than an enumerate() wrapper) lets tqdm read its length and show
        # a total.
        graphs_ptg = [
            _prepare(data, subgraph_dicts, subgraph_params, dataset_name,
                     extract_ids_fn, count_fn)
            for data in tqdm(graphs)
        ]

    return (graphs_ptg, num_classes, num_node_type, num_edge_type,
            orbit_partition_sizes)
74  
75  
76 # ------------------------------------------------------------------------
77  
  • E501 Line too long (81 > 79 characters)
def _prepare(data, subgraph_dicts, subgraph_params, dataset_name, ex_fn, cnt_fn):
    """Convert one loaded graph into a ``Data`` object with identifiers.

    Args:
        data: Loaded graph exposing ``edge_mat``, ``node_features``,
            ``label`` and, optionally, ``edge_features``.
        subgraph_dicts: One dict per substructure; each carries an
            ``'orbit_partition'`` entry (read here) and is otherwise
            forwarded to ``ex_fn`` untouched.
        subgraph_params: Substructure options, forwarded to ``ex_fn``.
        dataset_name: Dataset identifier; selects the label dtype.
        ex_fn: Callable invoked as
            ``ex_fn(cnt_fn, new_data, subgraph_dicts, subgraph_params)``;
            its return value is what this function returns.
        cnt_fn: Counting callable handed to ``ex_fn``; its ``__name__`` is
            inspected to special-case edge counts on edgeless graphs.

    Returns:
        The populated ``Data`` object, or whatever ``ex_fn`` returns.
    """
    new_data = Data()
    new_data.edge_index = data.edge_mat
    new_data.x = data.node_features
    new_data.graph_size = data.node_features.shape[0]

    # NOTE(review): edge_index is presumably laid out as (2, num_edges), so
    # a second dimension of 0 means the graph has no edges -- confirm
    # against the loaders.
    has_no_edges = new_data.edge_index.shape[1] == 0

    if has_no_edges:
        new_data.degrees = torch.zeros((new_data.graph_size,))
    else:
        # Pass num_nodes explicitly: without it, degree() sizes its output
        # from the largest index present, which would drop isolated nodes
        # sitting at the end of the node range and disagree with the
        # edgeless branch above.
        new_data.degrees = torch_geo.utils.degree(
            new_data.edge_index[0], num_nodes=new_data.graph_size)

    if hasattr(data, 'edge_features'):
        new_data.edge_features = data.edge_features

    # These three datasets get float labels; every other dataset gets long.
    label = data.label.clone().detach().unsqueeze(0)
    if dataset_name in {'ogbg-molpcba', 'ogbg-molhiv', 'ZINC'}:
        new_data.y = label.float()
    else:
        new_data.y = label.long()

    if has_no_edges and cnt_fn.__name__ == 'subgraph_isomorphism_edge_counts':
        # No edges means no edge identifiers: store an empty (0, total)
        # tensor instead of calling ex_fn. The total orbit count is derived
        # from subgraph_dicts because this function has no access to the
        # caller's local orbit_partition_sizes list.
        num_orbits = sum(len(subgraph_dict['orbit_partition'])
                         for subgraph_dict in subgraph_dicts)
        new_data.identifiers = torch.zeros((0, num_orbits)).long()
    else:
        new_data = ex_fn(cnt_fn, new_data, subgraph_dicts, subgraph_params)

    return new_data