# preprocessing/scaffold/smiles2scaffold.py
# (module: preprocessing.scaffold.smiles2scaffold)

import numpy as np
import pandas as pd
from rdkit import RDLogger
from rdkit.Chem.Scaffolds import MurckoScaffold

# Turn off RDKit log output for every log channel matching 'rdApp.*',
# so the per-molecule scaffold computations below do not flood the console.
RDLogger.DisableLog('rdApp.*')


def generate_scaffold(smiles, include_chirality=False):
    """
    Compute the Bemis-Murcko scaffold of a molecule given as SMILES.

    Thin wrapper around RDKit's ``MurckoScaffold.MurckoScaffoldSmiles``.

    :param smiles: SMILES string of the molecule
    :param include_chirality: forwarded to RDKit as ``includeChirality``
    :return: smiles of scaffold
    """
    return MurckoScaffold.MurckoScaffoldSmiles(
        smiles=smiles,
        includeChirality=include_chirality,
    )


# # test generate scaffold
# s = 'Cc1cc(Oc2nccc(CCC)c2)ccc1'
# scaffold = generate_scaffold(s)
# print(scaffold)
# assert scaffold == 'c1ccc(Oc2ccccn2)cc1'


def scaffold_split(smiles_list, frac_train=0.8, frac_valid=0.1, frac_test=0.1,
                   group_save_path='scaffold_group'):
    """
    Split dataset by Bemis-Murcko scaffolds. Deterministic split.
    Adapted from https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py

    Molecules sharing a scaffold always end up in the same subset. Scaffold
    groups are visited from largest to smallest; each group goes to train
    while it still fits under the train cutoff, otherwise to valid while it
    fits under the train+valid cutoff, otherwise to test.

    Side effect: an int64 array holding, for every molecule, the rank of its
    scaffold group in that size-sorted order (0 = largest group) is written
    with ``np.save`` to ``group_save_path``.

    :param smiles_list: list of smiles
    :param frac_train: fraction of the dataset targeted for the train subset
    :param frac_valid: fraction of the dataset targeted for the valid subset
    :param frac_test: fraction of the dataset targeted for the test subset
    :param group_save_path: destination passed to ``np.save`` for the
    per-molecule scaffold group ids (``np.save`` appends ``.npy`` to a path
    lacking it). Defaults to ``'scaffold_group'`` in the current working
    directory; pass ``None`` to skip writing the file.
    :return: list of train, valid, test indices corresponding to the
    scaffold split
    :raises AssertionError: if the three fractions do not sum to 1
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)

    # create dict of the form {scaffold_i: [idx1, idx....]}
    all_scaffolds = {}
    for i, smiles in enumerate(smiles_list):
        scaffold = generate_scaffold(smiles, include_chirality=True)
        all_scaffolds.setdefault(scaffold, []).append(i)

    # sort from largest to smallest sets; the first (smallest) index of each
    # group breaks ties between equally sized groups, keeping the order
    # deterministic
    all_scaffold_sets = sorted(
        (sorted(indices) for indices in all_scaffolds.values()),
        key=lambda indices: (len(indices), indices[0]),
        reverse=True,
    )

    # get train, valid test indices
    n_total = len(smiles_list)
    train_cutoff = frac_train * n_total
    valid_cutoff = (frac_train + frac_valid) * n_total
    train_idx, valid_idx, test_idx = [], [], []
    for scaffold_set in all_scaffold_sets:
        set_size = len(scaffold_set)
        if len(train_idx) + set_size <= train_cutoff:
            train_idx.extend(scaffold_set)
        elif len(train_idx) + len(valid_idx) + set_size <= valid_cutoff:
            valid_idx.extend(scaffold_set)
        else:
            test_idx.extend(scaffold_set)

    # sanity check: the three subsets must be pairwise disjoint
    assert len(set(train_idx).intersection(set(valid_idx))) == 0
    assert len(set(train_idx).intersection(set(test_idx))) == 0
    assert len(set(test_idx).intersection(set(valid_idx))) == 0

    # label every molecule with the rank of its scaffold group
    scaffold_group = np.zeros(n_total, dtype=np.int64)
    for group_id, members in enumerate(all_scaffold_sets):
        scaffold_group[members] = group_id

    if group_save_path is not None:
        np.save(group_save_path, scaffold_group)

    return train_idx, valid_idx, test_idx


if __name__ == '__main__':
    # Script entry point: read the 'smiles' column of the gzipped CSV in the
    # working directory and run the scaffold split over it (which also
    # writes the scaffold group array as a side effect).
    smiles_list = pd.read_csv('mol.csv.gz')['smiles'].tolist()
    train_idx, valid_idx, test_idx = scaffold_split(smiles_list)