Diff of /modas/coloc.py [000000] .. [a43cea]

Switch to unified view

a b/modas/coloc.py
1
import pandas as pd
2
import numpy as np
3
import bioframe as bf
4
from image_match.goldberg import ImageSignature
5
from pandas_plink import read_plink1_bin
6
from sklearn.cluster import DBSCAN
7
from scipy.spatial.distance import squareform
8
from scipy.cluster.hierarchy import linkage, leaves_list, cut_tree
9
from joblib import Parallel, delayed
10
from sklearn.metrics import silhouette_score, calinski_harabasz_score
11
import resource
12
import os
13
14
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, 4096))
15
16
17
def qtl_cluster(qtl):
18
    qtl.CHR = qtl.CHR.astype(str)
19
    qtl = bf.cluster(qtl, cols=['CHR', 'qtl_start', 'qtl_end'], min_dist=0)
20
    return qtl
21
22
23
def kin(g, top=2):
24
    g = g - g.mean()
25
    K = np.dot(g, g.T)
26
    d = np.diag(K)
27
    DL = np.min(d)
28
    DU = np.max(d)
29
    floor = np.min(K)
30
    K = top * (K - floor) / (DU - floor)
31
    Dmin = top * (DL - floor) / (DU - floor)
32
    dig_index = np.eye(K.shape[0], dtype=bool)
33
    if Dmin < 1:
34
        K[dig_index] = (np.diag(K)-Dmin+1)/((top+1-Dmin)*0.5)
35
        K[~dig_index] = K[~dig_index] * (1 / Dmin)
36
    Omax = np.max(K[~dig_index])
37
    if Omax > top:
38
        K[~dig_index] = K[~dig_index] * (top / Omax)
39
    return K
40
41
42
def get_kin_info(qtl, gwas_dir, geno, pvalue):
43
    var = pd.Series(geno.variant, index=geno.snp)
44
    kin_info = dict()
45
    for phe_name in qtl.phe_name.unique():
46
        fn = gwas_dir+'/tmp_' + phe_name + '_plink.assoc.txt'
47
        if not os.path.exists(fn):
48
            print('Warning: ' + fn + 'is not exist.')
49
            continue
50
        gwas = pd.read_csv(fn, sep='\t')
51
        geno_sub = geno.sel(variant=var.reindex(gwas.loc[gwas.p_wald <= pvalue, 'rs']).dropna().values, drop=True)
52
        geno_sub = pd.DataFrame(geno_sub.values, index=geno_sub.fid, columns=geno_sub.snp)
53
        kin_res = kin(geno_sub)
54
        ril_cluster = linkage(kin_res, method='ward')
55
        idx = leaves_list(ril_cluster)
56
        label = cut_tree(ril_cluster, n_clusters=2)[:, 0]
57
        kin_info[phe_name] = dict([['kin', kin_res], ['idx', idx], ['label', label]])
58
    return kin_info
59
60
61
def get_ril_cluster_idx(kin1, kin2, metric):
62
    score1 = calc_cluster_score(kin1['kin'], kin2['kin'], kin1['label'], metric)
63
    score2 = calc_cluster_score(kin1['kin'], kin2['kin'], kin2['label'], metric)
64
    if metric == 'silhouette':
65
        if score1 > score2:
66
            return kin1['idx']
67
        else:
68
            return kin2['idx']
69
    if metric == 'calinski_harabasz':
70
        if score1 > score2:
71
            return kin1['idx']
72
        else:
73
            return kin2['idx']
74
75
76
def calc_cluster_score(kin1, kin2, label, metric):
77
    if metric == 'silhouette':
78
        from sklearn.metrics import silhouette_score
79
        kin1_score = silhouette_score(kin1, label)
80
        kin2_score = silhouette_score(kin2, label)
81
    elif metric == 'calinski_harabasz':
82
        from sklearn.metrics import calinski_harabasz_score
83
        kin1_score = calinski_harabasz_score(kin1, label)
84
        kin2_score = calinski_harabasz_score(kin2, label)
85
    return np.mean([kin1_score, kin2_score])
86
87
88
def get_signature(g, gis):
89
    # image_limits = gis.crop_image(g,
90
    #                               lower_percentile=gis.lower_percentile,
91
    #                               upper_percentile=gis.upper_percentile,
92
    #                               fix_ratio=gis.fix_ratio)
93
    # x_coords, y_coords = gis.compute_grid_points(g,
94
    #                                              n=gis.n, window=image_limits)
95
    x_coords, y_coords = gis.compute_grid_points(g, n=gis.n)
96
    avg_grey = gis.compute_mean_level(g, x_coords, y_coords, P=gis.P)
97
    diff_mat = gis.compute_differentials(avg_grey,
98
                                         diagonal_neighbors=gis.diagonal_neighbors)
99
    gis.normalize_and_threshold(diff_mat,
100
                                identical_tolerance=gis.identical_tolerance,
101
                                n_levels=gis.n_levels)
102
    return np.ravel(diff_mat).astype('int8')
103
104
105
def calc_image_match_score(kin_info, phe_list, metric):
106
    gis = ImageSignature()
107
    score = list()
108
    for _, phe1 in enumerate(phe_list):
109
        for phe2 in phe_list[_+1:]:
110
            idx = get_ril_cluster_idx(kin_info[phe1], kin_info[phe2], metric)
111
            score.append(gis.normalized_distance(get_signature(kin_info[phe1]['kin'][idx, :][:, idx], gis), get_signature(kin_info[phe2]['kin'][idx, :][:, idx], gis)))
112
    score = squareform(score)
113
    return score
114
115
116
def cluster_coloc(kin_info, qtl, c, metric, cls):
117
    qtl_sub = qtl.loc[qtl.cluster==c, :]
118
    if qtl_sub.shape[0] > 1:
119
        phe_name = qtl_sub.phe_name.unique()
120
        trait_dis = calc_image_match_score(kin_info, phe_name, metric)
121
        trait_dis = np.round(trait_dis, 2)
122
        cls.fit(trait_dis)
123
        cls_res = pd.Series(cls.labels_, index=phe_name).to_frame().reset_index()
124
        cls_res.columns = ['phe_name', 'label']
125
        cls_res = cls_res.loc[cls_res.label != -1, :]
126
        if not cls_res.empty:
127
            cls_res['label'] = str(c) + '_' + cls_res.label.astype(str)
128
            return pd.DataFrame(trait_dis, index=phe_name, columns=phe_name), cls_res
129
        else:
130
            return pd.DataFrame(trait_dis, index=phe_name, columns=phe_name), pd.DataFrame()
131
    else:
132
        return pd.DataFrame(), pd.DataFrame()
133
134
135
def trait_coloc(kin_info, qtl, metric, eps, p):
136
    cls = DBSCAN(eps=eps, min_samples=2, metric='precomputed')
137
    # dis = list()
138
    # coloc_res = list()
139
    # coloc_count = 0
140
    # for c in qtl.cluster.unique():
141
    #     qtl_sub = qtl.loc[qtl.cluster==c, :]
142
    #     if qtl_sub.shape[0] > 1:
143
    #         phe_name = qtl_sub.phe_name.unique()
144
    #         trait_dis = calc_image_match_score(kin_info, phe_name, metric)
145
    #         dis.append(pd.DataFrame(trait_dis, index=phe_name, columns=phe_name))
146
    #         cls.fit(trait_dis)
147
    #         cls_res = pd.Series(cls.labels_, index=phe_name).to_frame().reset_index()
148
    #         cls_res.columns = ['phe_name', 'label']
149
    #         cls_res = cls_res.loc[cls_res.label != -1, :]
150
    #         if not cls_res.empty:
151
    #             cls_count = cls_res['label'].value_counts()
152
    #             cls_res['label'] = cls_res['label'].replace(cls_count.index, np.arange(coloc_count + 1, coloc_count + 1 + cls_count.shape[0]))
153
    #             coloc_res.append(cls_res)
154
    #             coloc_count = coloc_count + cls_count.shape[0]
155
    res = Parallel(n_jobs=p)(delayed(cluster_coloc)(kin_info, qtl, c, metric, cls) for c in qtl.cluster.unique())
156
    coloc_res = [i[1] for i in res]
157
    dis = [i[0] for i in res]
158
    coloc_res = pd.concat(coloc_res, axis=0)
159
    if not coloc_res.empty:
160
        coloc_res_count = coloc_res['label'].value_counts()
161
        coloc_res['label'] = coloc_res['label'].replace(coloc_res_count.index, np.arange(1, coloc_res_count.shape[0] + 1))
162
        coloc_res = pd.merge(qtl.drop(['cluster', 'cluster_start', 'cluster_end'], axis=1), coloc_res, on='phe_name', how='left')
163
        coloc_res = coloc_res.fillna(-1)
164
        coloc_res = coloc_res.loc[coloc_res.label != -1, :]
165
    dis = pd.concat(dis)
166
    dup_index = dis.index[dis.index.duplicated()]
167
    dup_dis = dis[dis.index.duplicated(keep=False)]
168
    dis = dis[~dis.index.duplicated()]
169
    for index in dup_index:
170
        dis.loc[index, :] = dup_dis.loc[index, :].apply(lambda x:pd.Series(x[~pd.isna(x)]).min() if(pd.isna(x).sum()!=x.shape[0]) else x[0], axis=0)
171
    dis = dis.fillna(1)
172
    dis_pairwise = dis.stack().reset_index()
173
    dis_pairwise.columns = ['level_0', 'level_1', 'image_match_score']
174
    dis_pairwise = dis_pairwise.loc[dis_pairwise.level_0 != dis_pairwise.level_1, :]
175
    dis_pairwise['id'] = dis_pairwise.apply(lambda x: ';'.join(sorted([x['level_0'], x['level_1']])), axis=1)
176
    dis_pairwise = dis_pairwise.drop_duplicates(subset='id')
177
    qtl_overlap = bf.overlap(qtl[['CHR', 'qtl_start', 'qtl_end', 'SNP', 'P', 'phe_name']], qtl[['CHR', 'qtl_start', 'qtl_end', 'SNP', 'P', 'phe_name']],
178
                             cols1=['CHR', 'qtl_start', 'qtl_end'], cols2=['CHR', 'qtl_start', 'qtl_end'], how='inner', suffixes=('_1', '_2'))
179
    qtl_overlap = qtl_overlap.loc[qtl_overlap.phe_name_1 != qtl_overlap.phe_name_2, :]
180
    qtl_overlap['id'] = qtl_overlap.apply(lambda x: ';'.join(sorted([x['phe_name_1'], x['phe_name_2']])), axis=1)
181
    qtl_overlap = qtl_overlap.drop_duplicates(subset='id')
182
    coloc_pairwise_res = pd.merge(qtl_overlap, dis_pairwise, on='id')
183
    coloc_pairwise_res = coloc_pairwise_res.drop(['id', 'level_0', 'level_1'], axis=1)
184
    coloc_pairwise_res['coloc'] = 'No'
185
    coloc_pairwise_res.loc[coloc_pairwise_res['image_match_score'] <= 0.2, 'coloc'] = 'Yes'
186
    dis = dis.reindex(qtl.phe_name.unique()).reindex(qtl.phe_name.unique(), axis=1)
187
    dis = dis.fillna(1)
188
    return coloc_res, coloc_pairwise_res, dis
189