1 import torch.nn as nn
2 from torch_geometric.nn import global_mean_pool
3
4 from models.gnn import GNN
5 from models.three_wl import ThreeWLGNNNet
6 from models.gsn.gnn import GNN_GSN
7 from models.mlp import MLP
8
9
-
E501
Line too long (117 > 79 characters)
def initialize_model(config, d_out, is_featurizer=False, full_dataset=None,
                     is_pooled=True, include_projector=False):
    """
    Initializes models according to the config.

    Args:
        - config: config object read via attribute access; uses
          ``config.model``, ``config.algorithm`` and ``config.model_kwargs``
          (a mapping forwarded as keyword arguments to the backbone).
        - d_out (int): the dimensionality of the model output.
        - is_featurizer (bool): whether to return a model or a
          (featurizer, classifier) pair that constitutes a model.
        - full_dataset: optional dataset object. When given, a GNN_GSN
          model is built from its feature/encoder attributes and
          ``config.model`` is not consulted.
        - is_pooled (bool): only honoured by the generic GNN backbone
          (i.e. not GSN, '3wlgnn' or 'mlp'). When False and
          is_featurizer=True, a (featurizer, pooler, classifier) triple is
          returned, with ``global_mean_pool`` as the pooler.
        - include_projector (bool): if True, add a projector head used by
          the GCL algorithm (see Output).
    Output:
        If is_featurizer=True:
            - featurizer: a model that outputs feature Tensors of shape
              (batch_size, ..., feature dimensionality)
            - classifier: a model that takes in feature Tensors and outputs
              predictions. In most cases, this is a linear layer.
        If is_featurizer=False:
            - model: a model that is equivalent to
              nn.Sequential(featurizer, classifier)
        If include_projector=True, instead returns
            (featurizer, projector, classifier,
             nn.Sequential(featurizer, classifier)).
    Raises:
        AssertionError: if config.model == 'mlp' (without full_dataset) is
            combined with an algorithm other than 'ERM'/'IRM', or if
            include_projector=True is used without the GCL algorithm,
            without is_featurizer, or without is_pooled.
    """
    # A featurizer is built without a task head (num_tasks=None) and gets a
    # separate linear classifier; a full model predicts `d_out` directly.
    num_tasks = None if is_featurizer else d_out

    if full_dataset is not None:
        # The full dataset is only used for GSN, whose input dimensions are
        # read from it. NOTE: this special case still needs refactoring
        # into the config-driven dispatch below.
        backbone = _build_gsn(config, full_dataset, num_tasks)
        needs_pooler = False
    else:
        backbone = _build_backbone(config, num_tasks, is_pooled)
        # Only the generic GNN backbone honours `is_pooled`.
        needs_pooler = (config.model not in ('3wlgnn', 'mlp')
                        and not is_pooled)

    if is_featurizer:
        classifier = nn.Linear(backbone.d_out, d_out)
        if needs_pooler:
            # The featurizer was built with is_pooled=False, so a pooler
            # (global_mean_pool) sits between it and the classifier.
            model = (backbone, global_mean_pool, classifier)
        else:
            model = (backbone, classifier)
    else:
        model = backbone

    if include_projector:
        model = _add_projector(config, model, is_featurizer, is_pooled)

    # The `needs_y` attribute specifies whether the model's forward function
    # needs to take in both (x, y).
    # If False, Algorithm.process_batch will call model(x).
    # If True, Algorithm.process_batch() will call model(x, y) during
    # training, and model(x, None) during eval.
    if not hasattr(model, 'needs_y'):
        # Sometimes model is a tuple of (featurizer, classifier, ...); every
        # element is tagged (this includes the shared global_mean_pool
        # function when it is used as the pooler).
        parts = model if isinstance(model, tuple) else (model,)
        for part in parts:
            part.needs_y = False

    return model


def _build_gsn(config, full_dataset, out_features):
    """Build a GNN_GSN whose input sizes are read from `full_dataset`."""
    return GNN_GSN(
        in_features=full_dataset.num_features,
        out_features=out_features,
        encoder_ids=full_dataset.encoder_ids,
        d_in_id=full_dataset.d_id,
        in_edge_features=full_dataset.num_edge_features,
        d_in_node_encoder=full_dataset.d_in_node_encoder,
        d_in_edge_encoder=full_dataset.d_in_edge_encoder,
        d_degree=full_dataset.d_degree,
        dataset_group=config.model_kwargs['dataset_group'])


def _build_backbone(config, num_tasks, is_pooled):
    """Build the backbone selected by `config.model` (non-GSN path)."""
    if config.model == '3wlgnn':
        return ThreeWLGNNNet(gnn_type=config.model, num_tasks=num_tasks,
                             **config.model_kwargs)
    if config.model == 'mlp':
        # Combinations with other algorithms have not been checked.
        assert config.algorithm in ('ERM', 'IRM')
        return MLP(gnn_type=config.model, num_tasks=num_tasks,
                   **config.model_kwargs)
    return GNN(gnn_type=config.model, num_tasks=num_tasks,
               is_pooled=is_pooled, **config.model_kwargs)


def _add_projector(config, model, is_featurizer, is_pooled):
    """Insert a projector head into a (featurizer, classifier) pair for GCL.

    The projector builds the inputs to GCL's similarity loss (based on
    SimCLR). The returned layout (featurizer, projector, classifier,
    full_model) mirrors how the GCL algorithm class unpacks it, similar to
    deepCORAL.
    """
    assert config.algorithm == 'GCL', ('The projector component '
                                       'is only used with GCL')
    assert is_featurizer, ('Need pre-packing of (featurizer, classifier) '
                           'into model to add projector')
    assert is_pooled, ('Expects only (featurizer, classifier), '
                       'not (featurizer, pooler, classifier), '
                       'i.e. whole graph embeddings')
    featurizer, classifier = model
    graph_embedding_dim = featurizer.d_out
    projector = nn.Sequential(
        nn.Linear(graph_embedding_dim, graph_embedding_dim),
        nn.ReLU(inplace=True),
        nn.Linear(graph_embedding_dim, graph_embedding_dim),
    )
    return (featurizer, projector, classifier,
            nn.Sequential(featurizer, classifier))