experiments.models.initializer

experiments/models/initializer.py
import torch.nn as nn
from torch_geometric.nn import global_mean_pool

from models.gnn import GNN
from models.three_wl import ThreeWLGNNNet
from models.gsn.gnn import GNN_GSN
from models.mlp import MLP


def initialize_model(config, d_out, is_featurizer=False, full_dataset=None, is_pooled=True, include_projector=False):
    """
    Initializes models according to the config
        Args:
            - config (dictionary): config dictionary
            - d_out (int): the dimensionality of the model output
            - is_featurizer (bool): whether to return a model or a (featurizer, classifier) pair that constitutes a model.
        Output:
            If is_featurizer=True:
            - featurizer: a model that outputs feature Tensors of shape (batch_size, ..., feature dimensionality)
            - classifier: a model that takes in feature Tensors and outputs predictions. In most cases, this is a linear layer.
            If is_featurizer=False:
            - model: a model that is equivalent to nn.Sequential(featurizer, classifier)
    """

    if full_dataset is None:
        if config.model == "3wlgnn":
            if is_featurizer:
                featurizer = ThreeWLGNNNet(gnn_type=config.model, num_tasks=None, **config.model_kwargs)
                classifier = nn.Linear(featurizer.d_out, d_out)
                model = (featurizer, classifier)
            else:
                model = ThreeWLGNNNet(gnn_type=config.model, num_tasks=d_out, **config.model_kwargs)
        elif config.model == 'mlp':
            assert config.algorithm in ('ERM', 'IRM')  # combinations with other algorithms have not been checked
            if is_featurizer:
                featurizer = MLP(gnn_type=config.model, num_tasks=None, **config.model_kwargs)
                classifier = nn.Linear(featurizer.d_out, d_out)
                model = (featurizer, classifier)
            else:
                model = MLP(gnn_type=config.model, num_tasks=d_out, **config.model_kwargs)
        else:
            if is_featurizer:
                featurizer = GNN(gnn_type=config.model, num_tasks=None, is_pooled=is_pooled, **config.model_kwargs)
                classifier = nn.Linear(featurizer.d_out, d_out)
                if is_pooled:
                    model = (featurizer, classifier)
                else:
                    # The featurizer outputs node-level embeddings, so mean-pool
                    # them into a whole-graph embedding before classification.
                    pooler = global_mean_pool
                    model = (featurizer, pooler, classifier)
            else:
                model = GNN(gnn_type=config.model, num_tasks=d_out, is_pooled=is_pooled, **config.model_kwargs)

    # The full dataset is only used for GSN, whose encoders need
    # dataset-level statistics. TODO: this needs to be refactored.
    else:
        if is_featurizer:
            featurizer = GNN_GSN(in_features=full_dataset.num_features,
                                 out_features=None,
                                 encoder_ids=full_dataset.encoder_ids,
                                 d_in_id=full_dataset.d_id,
                                 in_edge_features=full_dataset.num_edge_features,
                                 d_in_node_encoder=full_dataset.d_in_node_encoder,
                                 d_in_edge_encoder=full_dataset.d_in_edge_encoder,
                                 d_degree=full_dataset.d_degree,
                                 dataset_group=config.model_kwargs['dataset_group'])
            classifier = nn.Linear(featurizer.d_out, d_out)
            model = (featurizer, classifier)
        else:
            model = GNN_GSN(in_features=full_dataset.num_features,
                            out_features=d_out,
                            encoder_ids=full_dataset.encoder_ids,
                            d_in_id=full_dataset.d_id,
                            in_edge_features=full_dataset.num_edge_features,
                            d_in_node_encoder=full_dataset.d_in_node_encoder,
                            d_in_edge_encoder=full_dataset.d_in_edge_encoder,
                            d_degree=full_dataset.d_degree,
                            dataset_group=config.model_kwargs['dataset_group'])

    # The projector head is used to construct the inputs to the
    # similarity loss function for GCL, based on SimCLR.
    # Assumes model was packed as (featurizer, classifier) above.
    # The GCL algorithm class uses (featurizer, projector, classifier)
    # similarly to deepCORAL.
    if include_projector: 
        assert config.algorithm == 'GCL', 'The projector component ' \
                                          'is only used with GCL'
        assert is_featurizer, 'Need pre-packing of (featurizer, classifier) ' \
                              'into model to add projector'
        assert is_pooled, 'Expects only (featurizer, classifier), ' \
                          'not (featurizer, pooler, classifier), i.e. whole graph embeddings'
        featurizer, classifier = model
        graph_embedding_dim = featurizer.d_out
        # Two-layer MLP projection head operating on whole-graph embeddings.
        projector = nn.Sequential(
            nn.Linear(graph_embedding_dim, graph_embedding_dim),
            nn.ReLU(inplace=True),
            nn.Linear(graph_embedding_dim, graph_embedding_dim),
        )
        model = (featurizer, projector, classifier, nn.Sequential(featurizer, classifier))
    
    # The `needs_y` attribute specifies whether the model's forward function
    # needs to take in both (x, y).
    # If False, Algorithm.process_batch will call model(x).
    # If True, Algorithm.process_batch() will call model(x, y) during training,
    # and model(x, None) during eval.
    if not hasattr(model, 'needs_y'):
        # Sometimes model is a tuple of (featurizer, classifier, ...)
        if isinstance(model, tuple):
            for submodel in model:
                submodel.needs_y = False
        else:
            model.needs_y = False

    return model
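
For reference, a minimal usage sketch of initialize_model. The config fields below (model, algorithm, model_kwargs) are the ones the function reads; the concrete values, the 'gin' model name, and the use of SimpleNamespace in place of the experiments' real config class are illustrative assumptions, not the repository's actual configuration.

# Hypothetical usage sketch; SimpleNamespace stands in for the real config class.
from types import SimpleNamespace

import torch.nn as nn

config = SimpleNamespace(
    model='gin',        # assumed GNN type; must be one that GNN(...) accepts
    algorithm='ERM',
    model_kwargs={},    # real experiments pass model hyperparameters here
)

# End-to-end model, equivalent to nn.Sequential(featurizer, classifier).
model = initialize_model(config, d_out=10)

# (featurizer, classifier) pair, e.g. for algorithms that need intermediate
# graph embeddings such as deepCORAL.
featurizer, classifier = initialize_model(config, d_out=10, is_featurizer=True)
full_model = nn.Sequential(featurizer, classifier)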