# utils_ids
# gds/datasets/gsn/utils_ids.py

import torch
import numpy as np
from .utils_graph_processing import subgraph_isomorphism_vertex_counts, subgraph_isomorphism_edge_counts
from torch_geometric.utils import remove_self_loops


def subgraph_counts2ids(count_fn, data, subgraph_dicts, subgraph_params):
    """Attach subgraph-count structural identifiers to ``data``.

    Self loops are removed from ``data.edge_index`` first (``data.edge_features``
    is filtered alongside when that attribute exists). ``count_fn`` is then
    called once per entry of ``subgraph_dicts`` and the returned tensors are
    concatenated along dimension 1.

    Args:
        count_fn: Callable invoked as ``count_fn(edge_index, subgraph_dict=...,
            induced=..., num_nodes=..., directed=...)``. Its results must be
            tensors that can be concatenated along dimension 1.
            NOTE(review): presumably one of the vertex/edge count helpers
            imported at the top of this module -- confirm against the caller.
        data: Graph object exposing ``edge_index``, ``x`` and, optionally,
            ``edge_features``. It is modified in place.
        subgraph_dicts: Iterable of subgraph descriptions; each one is passed
            unchanged to ``count_fn``. Must contain at least one entry.
        subgraph_params: Mapping providing the ``'induced'`` and ``'directed'``
            settings that are forwarded to ``count_fn``.

    Returns:
        The same ``data`` object, with ``edge_index`` replaced by its
        self-loop-free version and ``identifiers`` set to the concatenated
        counts cast to ``torch.long``.

    Raises:
        ValueError: If ``subgraph_dicts`` yields no entries, because there is
            nothing to build the identifiers from.
    """
    # Self loops are dropped before counting; edge features (when present)
    # must be filtered with the same call so they stay aligned with the edges.
    if hasattr(data, 'edge_features'):
        edge_index, edge_features = remove_self_loops(data.edge_index, data.edge_features)
        data.edge_features = edge_features
    else:
        edge_index = remove_self_loops(data.edge_index)[0]

    # These arguments are identical for every subgraph, so build them once.
    shared_kwargs = {
        'induced': subgraph_params['induced'],
        'num_nodes': data.x.shape[0],
        'directed': subgraph_params['directed'],
    }
    per_subgraph_counts = [
        count_fn(edge_index, subgraph_dict=subgraph_dict, **shared_kwargs)
        for subgraph_dict in subgraph_dicts
    ]
    if not per_subgraph_counts:
        raise ValueError(
            "subgraph_dicts is empty: at least one subgraph is required "
            "to compute structural identifiers"
        )

    data.edge_index = edge_index
    # A single concatenation avoids re-copying the accumulated tensor on
    # every iteration.
    data.identifiers = torch.cat(per_subgraph_counts, dim=1).long()

    return data