1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
|
import os
import torch
import torch_geometric as torch_geo
from torch_geometric.utils import sort_edge_index, remove_self_loops
from torch_geometric.data import Data
from .utils_data_prep import load_data, load_zinc_data, load_ogb_data, load_g6_graphs
from .utils_graph_processing import automorphism_orbits
import multiprocessing as mp
import time
from .utils_misc import isnotebook
if isnotebook():
from tqdm import tqdm_notebook as tqdm
else:
from tqdm import tqdm
def generate_dataset(data_path,
                     dataset_name,
                     k,
                     extract_ids_fn,
                     count_fn,
                     automorphism_fn,
                     id_type,
                     multiprocessing=False,
                     num_processes=1,
                     **subgraph_params):
    """Load a dataset and convert every graph into an identifier-annotated ``Data``.

    For every substructure edge list in ``subgraph_params['edge_list']`` the
    orbit partition, orbit membership and automorphism count are computed once
    via ``automorphism_fn``. The dataset is then loaded (``load_zinc_data`` for
    ``'ZINC'``, ``load_ogb_data`` otherwise) and each graph is converted with
    ``_prepare``, either sequentially or in parallel with joblib.

    Args:
        data_path: Path whose parent directory (``data_path/..``) is handed to
            the dataset loaders.
        dataset_name: Name of the dataset to load.
        k: Not used by this function (kept for interface compatibility).
        extract_ids_fn: Callable invoked by ``_prepare`` as
            ``extract_ids_fn(count_fn, data, subgraph_dicts, subgraph_params)``.
        count_fn: Counting function forwarded to ``extract_ids_fn``.
        automorphism_fn: Callable taking ``edge_list``, ``directed`` and
            ``directed_orbits`` keywords and returning
            ``(subgraph, orbit_partition, orbit_membership, aut_count)``.
        id_type: Not used by this function (kept for interface compatibility).
        multiprocessing: If True, convert graphs with ``joblib.Parallel``.
        num_processes: Number of joblib workers when ``multiprocessing`` is True.
        **subgraph_params: Must contain ``'edge_list'``, ``'directed'`` and
            ``'directed_orbits'``; the whole dict is forwarded to ``_prepare``.

    Returns:
        Tuple ``(graphs_ptg, num_classes, num_node_type, num_edge_type,
        orbit_partition_sizes)``. ``num_node_type`` and ``num_edge_type`` are
        ``None`` for non-ZINC datasets; ``orbit_partition_sizes[i]`` is the
        number of orbits of the i-th substructure.

    Raises:
        ValueError: If ``'edge_list'`` is missing from ``subgraph_params``.
    """
    if 'edge_list' not in subgraph_params:
        raise ValueError('Edge list not provided.')

    # Compute the orbits of each substructure in the list, as well as its
    # vertex automorphism count.
    subgraph_dicts = []
    orbit_partition_sizes = []
    for edge_list in subgraph_params['edge_list']:
        subgraph, orbit_partition, orbit_membership, aut_count = automorphism_fn(
            edge_list=edge_list,
            directed=subgraph_params['directed'],
            directed_orbits=subgraph_params['directed_orbits'])
        subgraph_dicts.append({'subgraph': subgraph,
                               'orbit_partition': orbit_partition,
                               'orbit_membership': orbit_membership,
                               'aut_count': aut_count})
        orbit_partition_sizes.append(len(orbit_partition))

    # Load and preprocess the dataset; the loaders receive the parent
    # directory of data_path.
    data_path = os.path.join(data_path, os.pardir)
    if dataset_name == 'ZINC':
        graphs, num_classes, num_node_type, num_edge_type = load_zinc_data(
            data_path, dataset_name, False)
    else:
        graphs, num_classes = load_ogb_data(data_path, dataset_name, False)
        num_node_type, num_edge_type = None, None

    if multiprocessing:
        # Parallel computation of subgraph isomorphisms & creation of the
        # data structures.
        print("Preparing dataset in parallel...")
        start = time.time()
        from joblib import delayed, Parallel
        graphs_ptg = Parallel(n_jobs=num_processes, verbose=10)(
            delayed(_prepare)(graph, subgraph_dicts, subgraph_params,
                              dataset_name, extract_ids_fn, count_fn)
            for graph in graphs)
        print('Done ({:.2f} secs).'.format(time.time() - start))
    else:
        # Single-process computation. Iterate `graphs` directly (not
        # `enumerate(graphs)`) so tqdm can read its length and show a total.
        graphs_ptg = [
            _prepare(graph, subgraph_dicts, subgraph_params, dataset_name,
                     extract_ids_fn, count_fn)
            for graph in tqdm(graphs)
        ]
    return graphs_ptg, num_classes, num_node_type, num_edge_type, orbit_partition_sizes
# ------------------------------------------------------------------------
def _prepare(data, subgraph_dicts, subgraph_params, dataset_name, ex_fn, cnt_fn):
    """Convert one loaded graph into a PyG ``Data`` object with identifiers.

    Args:
        data: Source graph object. It must expose ``edge_mat``,
            ``node_features`` and ``label``, and may expose ``edge_features``.
        subgraph_dicts: Per-substructure dicts (as built by
            ``generate_dataset``); each holds an ``'orbit_partition'`` entry.
        subgraph_params: Substructure parameters forwarded to ``ex_fn``.
        dataset_name: Selects the label dtype: float for 'ogbg-molpcba',
            'ogbg-molhiv' and 'ZINC', long for every other dataset.
        ex_fn: Identifier-extraction callable, invoked as
            ``ex_fn(cnt_fn, new_data, subgraph_dicts, subgraph_params)``.
        cnt_fn: Counting function forwarded to ``ex_fn``.

    Returns:
        The populated ``Data`` object, or whatever ``ex_fn`` returns when it
        is called.
    """
    new_data = Data()
    new_data.edge_index = data.edge_mat
    new_data.x = data.node_features
    new_data.graph_size = data.node_features.shape[0]

    num_edges = new_data.edge_index.shape[1]
    if num_edges == 0:
        new_data.degrees = torch.zeros((new_data.graph_size,))
    else:
        # Pass num_nodes so that nodes with no outgoing edge at the end of
        # the index range still get a zero entry; otherwise degree() sizes
        # its output by the largest source index + 1.
        new_data.degrees = torch_geo.utils.degree(new_data.edge_index[0],
                                                  num_nodes=new_data.graph_size)

    if hasattr(data, 'edge_features'):
        new_data.edge_features = data.edge_features

    label = data.label.clone().detach().unsqueeze(0)
    if dataset_name in {'ogbg-molpcba', 'ogbg-molhiv', 'ZINC'}:
        new_data.y = label.float()
    else:
        new_data.y = label.long()

    if num_edges == 0 and cnt_fn.__name__ == 'subgraph_isomorphism_edge_counts':
        # Edgeless graph with an edge-counting function: store an empty
        # (0 x total_orbits) identifier matrix instead of calling ex_fn.
        # The width is derived from subgraph_dicts because
        # orbit_partition_sizes is a local of generate_dataset and is not in
        # scope here (referencing it raised NameError).
        num_orbits = sum(len(d['orbit_partition']) for d in subgraph_dicts)
        new_data.identifiers = torch.zeros((0, num_orbits)).long()
    else:
        new_data = ex_fn(cnt_fn, new_data, subgraph_dicts, subgraph_params)
    return new_data
|