Diff of /modas/mr.py [000000] .. [a43cea]

Switch to side-by-side view

--- a
+++ b/modas/mr.py
@@ -0,0 +1,226 @@
+import pandas as pd
+import numpy as np
+import pyranges as pr
+from sklearn.linear_model import LinearRegression
+from pandas_plink import read_plink1_bin
+from scipy.stats import chi2
+import modas.multiprocess as mp
+import subprocess
+import warnings
+import shutil
+import glob
+import sys
+import os
+import re
+
+
+def _find_utils_dir():
+    """Locate the ``modas/utils`` directory used as the prefix for external tools.
+
+    The ``utils`` directory sitting next to this module is preferred.  Only
+    when it is missing is the system ``locate`` database consulted, keeping the
+    first hit that lies under ``site-packages``.  If neither lookup succeeds a
+    warning is issued and the package-relative path is returned anyway, so
+    importing this module never fails just because ``locate`` is unavailable.
+
+    Returns:
+        str: path ending in ``utils`` (not guaranteed to exist).
+    """
+    bundled = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'utils')
+    if os.path.isdir(bundled):
+        return bundled
+    try:
+        located = subprocess.check_output(['locate', 'modas/utils'], text=True, encoding='utf-8')
+    except (OSError, subprocess.CalledProcessError):
+        # ``locate`` is not installed, or it exited non-zero (e.g. no match).
+        located = ''
+    for line in located.splitlines():
+        hit = re.search(r'^(.*site-packages.*?/modas/utils)(?:/|$)', line)
+        if hit:
+            # Cut the hit back to the utils directory itself, however deep the
+            # listed entry sits below it.
+            return hit.group(1)
+    warnings.warn('modas/utils directory not found; external tools may be unavailable')
+    return bundled
+
+
+utils_path = _find_utils_dir()
+
+
+def lm_res(y, X):
+    """Fit ``y ~ intercept + X`` for a single predictor and summarise the fit.
+
+    Observations where either ``y`` or ``X`` is NaN are dropped before fitting.
+
+    Args:
+        y: response values; must support boolean-mask indexing (the callers in
+            this file pass a DataFrame column via ``DataFrame.apply``).
+        X: one-dimensional predictor array of the same length as ``y``; must
+            support ``.reshape`` (the callers in this file pass ``.values``).
+
+    Returns:
+        pandas.Series with entries ``effect`` (slope), ``se`` (standard error
+        of the slope) and ``rsq`` (coefficient of determination).  ``se`` is
+        NaN when fewer than three observations remain or ``X`` is constant.
+    """
+    missing = np.isnan(X) | np.isnan(y)
+    X = X[~missing].reshape(-1, 1)
+    y = y[~missing]
+    lm = LinearRegression().fit(X, y)
+    effect = lm.coef_[-1]
+    rsq = lm.score(X, y)
+    # The fitted model has an intercept and a slope, so the residual variance
+    # uses n - 2 degrees of freedom and the slope's standard error is
+    # sqrt(sigma^2 / Sxx) with Sxx the centred sum of squares of the predictor.
+    dof = X.shape[0] - 2
+    sxx = np.sum((X[:, 0] - X[:, 0].mean()) ** 2)
+    if dof > 0 and sxx > 0:
+        sigma2 = np.sum((y - lm.predict(X)) ** 2) / dof
+        se = np.sqrt(sigma2 / sxx)
+    else:
+        se = np.nan
+    return pd.Series({'effect': effect, 'se': se, 'rsq': rsq})
+
+
+def MR(mTrait, pTrait, g, pvalue_cutoff):
+    """Test one exposure trait against every outcome trait using one SNP.
+
+    The SNP ``g`` is regressed against the exposure and against each outcome
+    with ``lm_res``; the causal effect is the ratio of the outcome slope to the
+    exposure slope.  The statistic is ``effect**2 / var`` with
+
+        var = Var(pTrait) * (1 - R2[pTrait ~ mTrait])
+              / (n * Var(mTrait) * R2[mTrait ~ SNP])
+
+    and its p-value is taken from a chi-square distribution with 1 df.
+
+    Args:
+        mTrait: pandas Series holding the exposure trait (its ``name`` is
+            reported in the result).
+        pTrait: pandas DataFrame with one outcome trait per column.
+        g: pandas Series of genotypes for the SNP (its ``name`` is reported).
+        pvalue_cutoff: rows with a p-value above this threshold are dropped.
+
+    Returns:
+        pandas.DataFrame with columns ``snp``, ``mTrait``, ``pTrait``,
+        ``effect``, ``TMR`` and ``pvalue``; rows where the outcome is the
+        exposure itself are removed.
+    """
+    p = pd.concat([mTrait.to_frame(), pTrait], axis=1)
+    snp_lm = p.apply(lm_res, args=[g.values])
+    mTrait_lm = p.apply(lm_res, args=[mTrait.values])
+    # Column 0 of ``p`` is the exposure, the remaining columns are outcomes.
+    # Use .iloc: plain integer keys on a label-indexed Series are deprecated.
+    snp_effect = snp_lm.loc['effect', :]
+    pTrait_mTrait = snp_effect.iloc[1:] / snp_effect.iloc[0]
+    var_upper = pTrait.var() * (1 - mTrait_lm.loc['rsq', :].iloc[1:])
+    var_down = mTrait.shape[0] * mTrait.var() * snp_lm.loc['rsq', :].iloc[0]
+    var = var_upper / var_down
+    TMR = pTrait_mTrait**2 / var
+    # The survival function keeps precision for very small p-values, where
+    # 1 - cdf would round to exactly 0.
+    pvalue = chi2.sf(TMR, 1)
+    n_outcomes = pTrait.shape[1]
+    MR_res = pd.DataFrame({
+        'snp': [g.name] * n_outcomes,
+        'mTrait': [mTrait.name] * n_outcomes,
+        'pTrait': pTrait.columns,
+        'effect': pTrait_mTrait,
+        'TMR': TMR,
+        'pvalue': pvalue,
+    })
+    MR_res = MR_res.loc[(MR_res.pvalue <= pvalue_cutoff) & (MR_res.mTrait != MR_res.pTrait), :]
+    return MR_res
+
+
+def MR_parallel(mTrait_qtl, mTrait, pTrait, geno, threads, pvalue_cutoff):
+    """Run ``MR`` for every (SNP, exposure trait) row of a QTL table.
+
+    Args:
+        mTrait_qtl: DataFrame with columns ``SNP`` and ``phe_name``.
+        mTrait: DataFrame of exposure traits, one column per ``phe_name``.
+        pTrait: DataFrame of outcome traits, handed unchanged to ``MR``.
+        geno: DataFrame of genotypes, one column per SNP id.
+        threads: worker count forwarded to ``mp.parallel``.
+        pvalue_cutoff: significance threshold forwarded to ``MR``.
+
+    Returns:
+        pandas.DataFrame: the per-row ``MR`` results concatenated; an empty
+        frame with the ``MR`` result columns when the QTL table has no rows.
+    """
+    args = [
+        (mTrait.loc[:, mTrait_name], pTrait, geno.loc[:, rs], pvalue_cutoff)
+        for rs, mTrait_name in zip(mTrait_qtl['SNP'], mTrait_qtl['phe_name'])
+    ]
+    if not args:
+        # pd.concat raises on an empty list; return an empty result instead.
+        return pd.DataFrame(columns=['snp', 'mTrait', 'pTrait', 'effect', 'TMR', 'pvalue'])
+    res = mp.parallel(MR, args, threads)
+    return pd.concat(list(res))
+
+
+def var_bXY(bzx, bzx_se, bzy, bzy_se):
+    """Return the approximate variance of the ratio estimate ``bzy / bzx``.
+
+    Args:
+        bzx: effect of the instrument on the exposure.
+        bzx_se: standard error of ``bzx``.
+        bzy: effect of the instrument on the outcome.
+        bzy_se: standard error of ``bzy``.
+
+    Returns:
+        ``bzx_se**2 * bzy**2 / bzx**4 + bzy_se**2 / bzx**2``; works
+        element-wise when array-like inputs are supplied.
+    """
+    # Contribution of the uncertainty in the exposure effect.
+    exposure_term = (bzx_se ** 2) * (bzy ** 2) / (bzx ** 4)
+    # Contribution of the uncertainty in the outcome effect.
+    outcome_term = (bzy_se ** 2) / (bzx ** 2)
+    return exposure_term + outcome_term
+
+
+def generate_geno_batch(mTrait_qtl, mTrait, pTrait, geno, threads, bed_dir, rs_dir):
+    """Extract per-trait SNP subsets with plink and attach phenotypes to the .fam files.
+
+    ``bed_dir`` and ``rs_dir`` are deleted if present and recreated.  For every
+    distinct ``phe_name`` in ``mTrait_qtl`` the trait's SNP ids are written to
+    ``<rs_dir>/<trait>_rs.txt`` and a plink command is queued that writes the
+    subset to ``<bed_dir>/<trait>``.  One further subset named ``pTrait``
+    holds every SNP of the table.  The queued commands are run through
+    ``mp.parallel(mp.run, ...)``; afterwards each resulting ``.fam`` file is
+    rewritten with phenotype values.
+
+    Args:
+        mTrait_qtl: DataFrame with columns ``SNP`` and ``phe_name``.
+        mTrait: DataFrame of exposure traits, one column per ``phe_name``.
+        pTrait: DataFrame of outcome traits.
+        geno: plink ``-bfile`` prefix of the full genotype data set.
+        threads: worker count forwarded to ``mp.parallel``.
+        bed_dir: output directory for the plink subsets (recreated).
+        rs_dir: output directory for the SNP id lists (recreated).
+    """
+    for work_dir in (bed_dir, rs_dir):
+        if os.path.exists(work_dir):
+            shutil.rmtree(work_dir)
+        os.mkdir(work_dir)
+    plink_extract = utils_path + '/plink -bfile {} -extract {} --make-bed -out {}'
+    geno_batch = list()
+    for mTrait_name in mTrait_qtl.phe_name.unique():
+        # os.path.join keeps absolute directories absolute, so the files land
+        # inside the directories created above.
+        out_name = os.path.join(bed_dir, mTrait_name)
+        rs = mTrait_qtl.loc[mTrait_qtl.phe_name == mTrait_name, 'SNP']
+        rs_name = os.path.join(rs_dir, mTrait_name + '_rs.txt')
+        rs.to_csv(rs_name, index=False, header=False)
+        geno_batch.append((plink_extract.format(geno, rs_name, out_name),))
+    out_name = os.path.join(bed_dir, 'pTrait')
+    rs_name = os.path.join(rs_dir, 'pTrait_rs.txt')
+    mTrait_qtl['SNP'].to_csv(rs_name, index=False, header=False)
+    geno_batch.append((plink_extract.format(geno, rs_name, out_name),))
+    mp.parallel(mp.run, geno_batch, threads)
+    for fn in glob.glob(os.path.join(bed_dir, '*.fam')):
+        fam = pd.read_csv(fn, sep=' ', header=None)
+        # Strip only the trailing extension; trait names may contain dots.
+        mTrait_name = os.path.splitext(os.path.basename(fn))[0]
+        # Samples are matched to phenotypes on the first .fam column.
+        if mTrait_name == 'pTrait':
+            # Append every outcome trait as an extra column after the six
+            # existing .fam columns.
+            outcome = pTrait.reindex(fam[0])
+            fam.index = fam[0]
+            fam = pd.concat([fam, outcome], axis=1)
+        else:
+            # Overwrite the sixth .fam column with the exposure trait.
+            fam.loc[:, 5] = mTrait.loc[:, mTrait_name].reindex(fam[0]).values
+        fam.to_csv(fn, index=False, header=None, sep=' ', na_rep='NA')
+
+
+def calc_MLM_effect(bed_dir, pTrait, threads, geno):
+    """Run gemma mixed-model association on every plink subset in ``bed_dir``.
+
+    A kinship matrix is first computed from the full genotype set.  For that
+    step a temporary ``<prefix>.link`` file set is created in the current
+    directory: symlinks to the original ``.bed``/``.bim`` plus a copy of the
+    ``.fam`` whose sixth column is set to 1.  Then one gemma ``-lmm`` command
+    is queued per exposure subset (``-n 1``) and one per outcome trait column
+    of the ``pTrait`` subset (``-n 2``, ``-n 3``, ...).  The commands read the
+    kinship matrix from ``./output/<prefix>.cXX.txt``.
+
+    Args:
+        bed_dir: directory holding the plink subsets (``*.bed``).
+        pTrait: DataFrame of outcome traits; only its column names and order
+            are used here.
+        threads: worker count forwarded to ``mp.parallel``.
+        geno: plink prefix of the full genotype data set.
+
+    Returns:
+        ``None`` when the kinship command returns a non-zero status, otherwise
+        whatever ``mp.parallel(mp.run, ...)`` returns for the queued commands.
+    """
+    args = list()
+    geno_prefix = geno.split('/')[-1]
+    link_prefix = geno_prefix + '.link'
+    fam = pd.read_csv(geno + '.fam', sep=r'\s+', header=None)
+    fam[5] = 1
+    try:
+        fam.to_csv(link_prefix + '.fam', sep='\t', na_rep='NA', header=None, index=False)
+        for ext in ('.bed', '.bim'):
+            # lexists also detects a dangling symlink left by an earlier run,
+            # which would otherwise make os.symlink fail.
+            if os.path.lexists(link_prefix + ext):
+                os.remove(link_prefix + ext)
+            os.symlink(geno + ext, link_prefix + ext)
+        related_matrix_cmd = utils_path + '/gemma -bfile {0}.link -gk 1 -o {1}'.format(geno_prefix, geno_prefix)
+        s = mp.run(related_matrix_cmd)
+        if s != 0:
+            return None
+        gemma_cmd_mTrait = utils_path + '/gemma -bfile {0} -k ./output/{1}.cXX.txt -lmm -n 1 -o {2}'
+        gemma_cmd_pTrait = utils_path + '/gemma -bfile {0} -k ./output/{1}.cXX.txt -lmm -n {2} -o {3}'
+        for bed_file in glob.glob(bed_dir + '/*.bed'):
+            bfile = bed_file[:-len('.bed')]
+            prefix = bfile.split('/')[-1]
+            if prefix != 'pTrait':
+                args.append((gemma_cmd_mTrait.format(bfile, geno_prefix, 'mTrait_' + prefix),))
+            else:
+                # Phenotype 1 is the original sixth .fam column; the outcome
+                # traits follow it, hence numbering from 2.
+                for pheno_no, pTrait_name in enumerate(pTrait.columns, start=2):
+                    args.append((gemma_cmd_pTrait.format(bfile, geno_prefix, pheno_no, 'pTrait_' + pTrait_name),))
+        return mp.parallel(mp.run, args, threads)
+    finally:
+        # Remove the temporary link files on every exit path, including the
+        # early return after a failed kinship run.
+        for ext in ('.bed', '.bim', '.fam'):
+            if os.path.lexists(link_prefix + ext):
+                os.remove(link_prefix + ext)
+
+
+def get_MLM_effect(fn):
+    """Load effect sizes and standard errors from a tab-separated association table.
+
+    Args:
+        fn: path of a tab-separated file with at least the columns ``rs``,
+            ``beta`` and ``se``.
+
+    Returns:
+        pandas.DataFrame with columns ``beta`` and ``se``, indexed by the
+        values of the ``rs`` column.
+    """
+    table = pd.read_csv(fn, sep='\t')
+    # Index by SNP id while keeping the ``rs`` column itself in the table.
+    by_snp = table.set_index('rs', drop=False)
+    return by_snp.loc[:, ['beta', 'se']]
+
+
+def get_MLM_effect_parallell(assoc_dir, mTrait, pTrait, threads):
+    """Collect association results for all exposure and outcome traits.
+
+    Exposure results are read from ``<assoc_dir>/mTrait_<name>.assoc.txt`` for
+    every column of ``mTrait`` and stacked into one table whose index is
+    ``"<trait>;<rs>"``.  Outcome results are read from
+    ``<assoc_dir>/pTrait_<name>.assoc.txt`` for every column of ``pTrait``
+    through ``mp.parallel(get_MLM_effect, ...)`` and arranged as SNP-by-trait
+    tables.
+
+    Args:
+        assoc_dir: directory containing the ``*.assoc.txt`` files.
+        mTrait: DataFrame whose column names are the exposure trait names.
+        pTrait: DataFrame whose column names are the outcome trait names.
+        threads: worker count forwarded to ``mp.parallel``.
+
+    Returns:
+        tuple ``(mTrait_effect, pTrait_effect, pTrait_se)``: the stacked
+        ``beta``/``se`` table of the exposures, and the ``beta`` and ``se``
+        tables of the outcomes with one column per outcome trait.
+    """
+    mTrait_frames = []
+    for mTrait_name in mTrait.columns:
+        # os.path.join keeps an absolute assoc_dir absolute.
+        fn = os.path.join(assoc_dir, 'mTrait_' + mTrait_name + '.assoc.txt')
+        assoc = pd.read_csv(fn, sep='\t')
+        assoc.index = mTrait_name + ';' + assoc['rs']
+        mTrait_frames.append(assoc[['beta', 'se']])
+    # Concatenate once instead of growing a DataFrame inside the loop.
+    mTrait_effect = pd.concat(mTrait_frames) if mTrait_frames else pd.DataFrame()
+    args = [
+        (os.path.join(assoc_dir, 'pTrait_' + pTrait_name + '.assoc.txt'),)
+        for pTrait_name in pTrait.columns
+    ]
+    # The results are relabelled with pTrait.columns below, which relies on
+    # mp.parallel returning them in the order of ``args``.
+    pTrait_res = list(mp.parallel(get_MLM_effect, args, threads))
+    pTrait_effect = pd.concat([i['beta'] for i in pTrait_res], axis=1)
+    pTrait_effect.columns = pTrait.columns
+    pTrait_se = pd.concat([i['se'] for i in pTrait_res], axis=1)
+    pTrait_se.columns = pTrait.columns
+    return mTrait_effect, pTrait_effect, pTrait_se
+
+
+def MR_MLM(mTrait_effect_snp, pTrait_effect_snp, pTrait_se_snp, pvalue_cutoff):
+    """Test one exposure against all outcomes from precomputed association effects.
+
+    The causal effect is the ratio of each outcome effect to the exposure
+    effect, its variance comes from ``var_bXY``, and the statistic
+    ``effect**2 / variance`` is referred to a chi-square distribution with
+    1 df.
+
+    Args:
+        mTrait_effect_snp: Series with entries ``beta`` and ``se`` whose
+            ``name`` has the form ``"<trait>;<rs>"``.
+        pTrait_effect_snp: Series of outcome effects for the same SNP, indexed
+            by outcome trait name.
+        pTrait_se_snp: Series of the matching outcome standard errors.
+        pvalue_cutoff: rows with a p-value above this threshold are dropped.
+
+    Returns:
+        pandas.DataFrame with columns ``snp``, ``mTrait``, ``pTrait``,
+        ``effect``, ``TMR`` and ``pvalue``; rows where the outcome is the
+        exposure itself are removed.
+    """
+    mTrait_name, rs = mTrait_effect_snp.name.split(';')
+    exposure_beta = mTrait_effect_snp['beta']
+    bxy = pTrait_effect_snp / exposure_beta
+    varbXY = var_bXY(exposure_beta, mTrait_effect_snp['se'], pTrait_effect_snp, pTrait_se_snp)
+    TMR = bxy**2 / varbXY
+    # The survival function keeps precision for very small p-values, where
+    # 1 - cdf would round to exactly 0.
+    pvalue = chi2.sf(TMR, 1)
+    n_outcomes = pTrait_effect_snp.shape[0]
+    MR_res = pd.DataFrame({
+        'snp': [rs] * n_outcomes,
+        'mTrait': [mTrait_name] * n_outcomes,
+        'pTrait': pTrait_effect_snp.index,
+        'effect': bxy,
+        'TMR': TMR,
+        'pvalue': pvalue,
+    })
+    MR_res = MR_res.loc[(MR_res.pvalue <= pvalue_cutoff) & (MR_res.mTrait != MR_res.pTrait), :]
+    return MR_res
+
+
+def MR_MLM_parallel(mTrait_qtl, mTrait_effect, pTrait_effect, pTrait_se, threads, pvalue_cutoff):
+    """Run ``MR_MLM`` for every (exposure trait, SNP) row of a QTL table.
+
+    Args:
+        mTrait_qtl: DataFrame with columns ``phe_name`` and ``SNP``.
+        mTrait_effect: exposure effects indexed by ``"<trait>;<rs>"``.
+        pTrait_effect: outcome effects indexed by SNP id, one column per trait.
+        pTrait_se: outcome standard errors with the same layout.
+        threads: worker count forwarded to ``mp.parallel``.
+        pvalue_cutoff: significance threshold forwarded to ``MR_MLM``.
+
+    Returns:
+        pandas.DataFrame: the per-row ``MR_MLM`` results concatenated; an empty
+        frame with the ``MR_MLM`` result columns when the QTL table has no
+        rows.
+    """
+    args = [
+        (mTrait_effect.loc[';'.join([mTrait_name, rs]), :],
+         pTrait_effect.loc[rs, :],
+         pTrait_se.loc[rs, :],
+         pvalue_cutoff)
+        for mTrait_name, rs in zip(mTrait_qtl['phe_name'], mTrait_qtl['SNP'])
+    ]
+    if not args:
+        # pd.concat raises on an empty list; return an empty result instead.
+        return pd.DataFrame(columns=['snp', 'mTrait', 'pTrait', 'effect', 'TMR', 'pvalue'])
+    res = mp.parallel(MR_MLM, args, threads)
+    return pd.concat(list(res))
+
+
+def _peak_intervals(qtl, pos, flank):
+    """Build a Chromosome/Start/End/phe_name table spanning ``pos`` +/- ``flank``."""
+    intervals = qtl[['CHR', 'phe_name']].copy()
+    intervals['start'] = (pos - flank).to_numpy()
+    intervals['end'] = (pos + flank).to_numpy()
+    intervals = intervals[['CHR', 'start', 'end', 'phe_name']]
+    intervals.columns = ['Chromosome', 'Start', 'End', 'phe_name']
+    return intervals
+
+
+def edge_weight(qtl, MR_res, window=1000000, min_weight=0.2):
+    """Weight each exposure/outcome pair by how often comparable results colocalize.
+
+    Two traits count as colocalized when the peak SNP of the first lies within
+    ``window`` base pairs of a peak SNP of the second on the same chromosome.
+    The SNP position is taken from the text after the last ``_`` of the SNP
+    id.  For an MR result with p-value ``p`` the weight is
+
+        (rows of the MR/colocalization merge with p-value <= p)
+        / (MR rows with p-value <= p)
+
+    Args:
+        qtl: DataFrame with columns ``CHR``, ``SNP`` and ``phe_name``.
+        MR_res: DataFrame with columns ``mTrait``, ``pTrait`` and ``pvalue``;
+            it is not modified.
+        window: half-width in base pairs of the region around each peak.
+        min_weight: pairs with a smaller weight are dropped.
+
+    Returns:
+        pandas.DataFrame with columns ``mTrait``, ``pTrait`` and ``weight``,
+        one row per trait pair (the largest weight of the pair is kept).
+    """
+    # Work on a private, p-value-sorted copy so the new column never touches
+    # (or warns about) the caller's frame.
+    MR_res = MR_res.loc[MR_res.mTrait != MR_res.pTrait, :].sort_values(by='pvalue').copy()
+    pos = qtl['SNP'].apply(lambda x: int(x.split('_')[-1]))
+    qtl_peak_pos = pr.PyRanges(_peak_intervals(qtl, pos, 1))
+    qtl_peak_range = pr.PyRanges(_peak_intervals(qtl, pos, window))
+    coloc = qtl_peak_pos.join(qtl_peak_range)
+
+    overlaps = [coloc.dfs[k] for k in sorted(coloc.dfs.keys())]
+    if overlaps:
+        res = pd.concat(overlaps)
+    else:
+        # No overlap at all: keep the columns the merge below relies on.
+        res = pd.DataFrame(columns=['phe_name', 'phe_name_b'])
+    res = res.loc[res.phe_name != res.phe_name_b, :]
+    mr_coloc = pd.merge(MR_res, res, left_on=['mTrait', 'pTrait'], right_on=['phe_name', 'phe_name_b'])
+
+    # Count, for every MR p-value, how many entries are <= it.  searchsorted
+    # on sorted arrays replaces a per-row boolean mask over all p-values.
+    pvalues = MR_res['pvalue'].to_numpy()
+    coloc_pvalues = np.sort(mr_coloc['pvalue'].to_numpy())
+    coloc_le = np.searchsorted(coloc_pvalues, pvalues, side='right')
+    total_le = np.searchsorted(pvalues, pvalues, side='right')
+    MR_res['weight'] = coloc_le / total_le
+
+    edgeweight = MR_res[['mTrait', 'pTrait', 'weight']]
+    edgeweight = edgeweight.loc[edgeweight.weight >= min_weight, :]
+    # Descending sort puts the largest weight of each pair first, which is the
+    # row drop_duplicates keeps.
+    edgeweight = edgeweight.sort_values(by=['mTrait', 'pTrait', 'weight'], ascending=False).drop_duplicates(subset=['mTrait', 'pTrait'])
+    return edgeweight
+
+