experiments.models.mlp

experiments/models/mlp.py
import torch
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import AtomEncoder
from torch_geometric.nn import global_mean_pool

class MLP(torch.nn.Module):
    def __init__(self, gnn_type, dataset_group, num_tasks=128, num_layers=5, emb_dim=300, dropout=0.5, is_pooled=True,
                 **model_kwargs):
        """
        Args:
            - num_tasks (int): number of binary label tasks. default to 128 (number of tasks of ogbg-molpcba)
            - num_layers (int): number of message passing layers of GNN
            - emb_dim (int): dimensionality of hidden channels
            - dropout (float): dropout ratio applied to hidden channels
        """
        super(MLP, self).__init__()

        self.dataset_group = dataset_group

        self.num_layers = num_layers
        self.dropout = dropout
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.is_pooled = is_pooled

        if num_tasks is None:
            self.d_out = self.emb_dim
        else:
            self.d_out = self.num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of MLP layers must be greater than 1.")


        ##################################################################################
        if self.dataset_group == 'mol':
            self.node_encoder = AtomEncoder(emb_dim)
        elif self.dataset_group == 'ppa':
            self.node_encoder = torch.nn.Embedding(1, emb_dim) # uniform input node embedding
        elif self.dataset_group == 'RotatedMNIST':
            self.node_encoder = torch.nn.Linear(1, emb_dim)
        elif self.dataset_group == 'ColoredMNIST' :
            self.node_encoder = torch.nn.Linear(2, emb_dim)
            # self.node_encoder_cate = torch.nn.Embedding(8, emb_dim)
        elif self.dataset_group == 'SBM' :
            self.node_encoder = torch.nn.Embedding(8, emb_dim)
        elif self.dataset_group == 'UPFD':
            self.node_encoder = torch.nn.Embedding(8, emb_dim)
        else:
            raise NotImplementedError
        ### Stack of linear layers and batch norms
        self.fcs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layers):
            self.fcs.append(torch.nn.Linear(emb_dim, emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
        ##################################################################################

        # Pooling function to generate whole-graph embeddings
        self.pool = global_mean_pool
        if num_tasks is None:
            self.graph_pred_linear = None
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        x = batched_data.x
        # if self.dataset_group == 'ColoredMNIST' :
        #     # x = self.node_encoder(x[:, :2]) + self.node_encoder_cate(x[:, 2:].to(torch.int).squeeze())
        #     x = self.node_encoder(x[:, :2])
        # else :
        x = self.node_encoder(x)
        for i in range(self.num_layers):
            x = self.fcs[i](x)
            x = self.batch_norms[i](x)
            # No ReLU after the final layer; dropout is applied at every layer
            if i == self.num_layers - 1:
                x = F.dropout(x, self.dropout, training=self.training)
            else:
                x = F.dropout(F.relu(x), self.dropout, training=self.training)
        # Mean-pool node embeddings into a single whole-graph embedding
        x = self.pool(x, batched_data.batch)

        if self.graph_pred_linear is None:
            return x
        else:
            return self.graph_pred_linear(x)
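
if __name__ == "__main__":
    # Hedged usage sketch (not part of the original module): builds a small random
    # batch of graphs to illustrate the expected input/output shapes. It assumes
    # dataset_group='RotatedMNIST', whose encoder expects one float feature per node;
    # the other dataset groups expect categorical features matching the encoders
    # defined above.
    from torch_geometric.data import Batch, Data

    graphs = []
    for _ in range(4):
        num_nodes = 10
        x = torch.rand(num_nodes, 1)                        # one scalar feature per node
        edge_index = torch.randint(0, num_nodes, (2, 20))   # edges are ignored by the MLP
        graphs.append(Data(x=x, edge_index=edge_index))
    batch = Batch.from_data_list(graphs)

    model = MLP(gnn_type=None, dataset_group='RotatedMNIST', num_tasks=10,
                num_layers=3, emb_dim=64, dropout=0.5)
    model.eval()
    with torch.no_grad():
        out = model(batch)
    print(out.shape)  # torch.Size([4, 10]): one prediction vector per graph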