--- a +++ b/modas/mr.py @@ -0,0 +1,226 @@ +import pandas as pd +import numpy as np +import pyranges as pr +from sklearn.linear_model import LinearRegression +from pandas_plink import read_plink1_bin +from scipy.stats import chi2 +import modas.multiprocess as mp +import subprocess +import warnings +import shutil +import glob +import sys +import os +import re + + +utils_path = subprocess.check_output('locate modas/utils', shell=True, text=True, encoding='utf-8') +#utils_path = '/'.join(re.search('\n(.*site-packages.*)\n', utils_path).group(1).split('/')[:-1]) +utils_path = re.search('\n(.*site-packages.*)\n', utils_path).group(1) +if not utils_path.endswith('utils'): + utils_path = '/'.join(utils_path.split('/')[:-1]) + + +def lm_res(y, X): + idx = np.isnan(X) | np.isnan(y) + X = X[~idx].reshape(-1, 1) + y = y[~idx] + lm = LinearRegression().fit(X, y) + sigma2 = np.sum((y - lm.predict(X))**2)/(X.shape[0]-1) + se = np.sqrt(np.diag(np.linalg.pinv(np.dot(X.T, X))))[-1] * sigma2 + effect = lm.coef_[-1] + rsq = lm.score(X, y) + return pd.Series(dict(zip(['effect', 'se', 'rsq'], [effect, se, rsq]))) + + +def MR(mTrait, pTrait, g, pvalue_cutoff): + p = pd.concat([mTrait.to_frame(), pTrait], axis=1) + snp_lm = p.apply(lm_res, args=[g.values]) + mTrait_lm = p.apply(lm_res, args=[mTrait.values]) + pTrait_mTrait = snp_lm.loc['effect', :][1:] / snp_lm.loc['effect', :][0] + var_upper = pTrait.var() * (1 - mTrait_lm.loc['rsq', :][1:]) + var_down = mTrait.shape[0] * mTrait.var() * snp_lm.loc['rsq', :][0] + var = var_upper / var_down + TMR = pTrait_mTrait**2 / var + pvalue = 1 - chi2.cdf(TMR, 1) + MR_res = pd.DataFrame(dict(zip(['snp', 'mTrait', 'pTrait', 'effect', 'TMR', 'pvalue'], + [[g.name]*pTrait.shape[1], [mTrait.name]*pTrait.shape[1], pTrait.columns, pTrait_mTrait, TMR, pvalue]))) + MR_res = MR_res.loc[(MR_res.pvalue<=pvalue_cutoff) & (MR_res.mTrait!=MR_res.pTrait),:] + return MR_res + + +def MR_parallel(mTrait_qtl, mTrait, pTrait, geno, threads, pvalue_cutoff): + args = list() + for index, row in mTrait_qtl.iterrows(): + rs = row['SNP'] + mTrait_name = row['phe_name'] + args.append((mTrait.loc[:, mTrait_name], pTrait, geno.loc[:, rs], pvalue_cutoff)) + res = mp.parallel(MR, args, threads) + res = pd.concat([i for i in res]) + return res + + +def var_bXY(bzx, bzx_se, bzy, bzy_se): + varbXY = bzx_se**2*bzy**2/bzx**4 + bzy_se**2/bzx**2 + return varbXY + + +def generate_geno_batch(mTrait_qtl, mTrait, pTrait, geno, threads, bed_dir, rs_dir): + if os.path.exists(bed_dir): + shutil.rmtree(bed_dir) + os.mkdir(bed_dir) + if os.path.exists(rs_dir): + shutil.rmtree(rs_dir) + os.mkdir(rs_dir) + plink_extract = utils_path + '/plink -bfile {} -extract {} --make-bed -out {}' + geno_batch = list() + for mTrait_name in mTrait_qtl.phe_name.unique(): + out_name = bed_dir.strip('/') + '/' + mTrait_name + rs = mTrait_qtl.loc[mTrait_qtl.phe_name == mTrait_name, 'SNP'] + rs_name = rs_dir.strip('/') + '/' + '_'.join([mTrait_name,'rs.txt']) + pd.Series(rs).to_frame().to_csv(rs_name, index=False, header=None) + geno_batch.append((plink_extract.format(geno, rs_name, out_name),)) + out_name = bed_dir.strip('/') + '/pTrait' + rs_name = rs_dir.strip('/') + '/pTrait_rs.txt' + mTrait_qtl['SNP'].to_frame().to_csv(rs_name, index=False, header=None) + geno_batch.append((plink_extract.format(geno, rs_name, out_name),)) + mp.parallel(mp.run, geno_batch, threads) + for fn in glob.glob(bed_dir.strip('/')+'/*fam'): + fam = pd.read_csv(fn, sep=' ', header=None) + mTrait_name = fn.split('/')[-1].replace('.fam', '') + if mTrait_name == 'pTrait': + pTrait = pTrait.reindex(fam[0]) + fam.index = fam[0] + fam = pd.concat([fam, pTrait], axis=1) + else: + fam.loc[:, 5] = mTrait.loc[:, mTrait_name].reindex(fam[0]).values + fam.to_csv(fn, index=False, header=None, sep=' ', na_rep='NA') + + +def calc_MLM_effect(bed_dir, pTrait, threads, geno): + args = list() + geno_prefix = geno.split('/')[-1] + fam = pd.read_csv(geno + '.fam', sep=r'\s+', header=None) + fam[5] = 1 + fam.to_csv(geno_prefix + '.link.fam', sep='\t', na_rep='NA', header=None, index=False) + if os.path.exists(geno_prefix + '.link.bed'): + os.remove(geno_prefix + '.link.bed') + if os.path.exists(geno_prefix + '.link.bim'): + os.remove(geno_prefix + '.link.bim') + os.symlink(geno + '.bed', geno_prefix + '.link.bed') + os.symlink(geno + '.bim', geno_prefix + '.link.bim') + related_matrix_cmd = utils_path + '/gemma -bfile {0}.link -gk 1 -o {1}'.format(geno_prefix, geno_prefix) + s = mp.run(related_matrix_cmd) + if s != 0: + return None + gemma_cmd_mTrait = utils_path + '/gemma -bfile {0} -k ./output/{1}.cXX.txt -lmm -n 1 -o {2}' + gemma_cmd_pTrait = utils_path + '/gemma -bfile {0} -k ./output/{1}.cXX.txt -lmm -n {2} -o {3}' + for i in glob.glob(bed_dir + '/*.bed'): + i = i.replace('.bed', '') + if i.split('/')[-1] != 'pTrait': + prefix = i.split('/')[-1] + args.append((gemma_cmd_mTrait.format(i, geno_prefix, 'mTrait_' + prefix),)) + else: + for _, pTrait_name in enumerate(pTrait.columns): + args.append((gemma_cmd_pTrait.format(i, geno_prefix, _ + 2, 'pTrait_' + pTrait_name),)) + s = mp.parallel(mp.run, args, threads) + os.remove(geno_prefix + '.link.bed') + os.remove(geno_prefix + '.link.bim') + os.remove(geno_prefix + '.link.fam') + return s + + +def get_MLM_effect(fn): + assoc = pd.read_csv(fn, sep='\t') + assoc.index = assoc['rs'] + return assoc[['beta', 'se']] + + +def get_MLM_effect_parallell(assoc_dir, mTrait, pTrait, threads): + mTrait_effect = pd.DataFrame() + args = [] + #pTrait_name = [] + # for fn in glob.glob(assoc_dir.strip('/') + '/mTrait*.assoc.txt'): + # mTrait_name = fn.split('/')[-1].split('_')[-1].replace('.assoc.txt', '') + # assoc = pd.read_csv(fn, sep='\t') + # assoc.index = mTrait_name+';' + assoc['rs'] + # mTrait_effect = pd.concat([mTrait_effect, assoc[['beta', 'se']]]) + # for fn in glob.glob(assoc_dir.strip('/') + '/pTrait*assoc.txt'): + # pTrait_name.append(fn.split('/')[-1].split('_')[-1].replace('.assoc.txt', '')) + # args.append((fn,)) + for mTrait_name in mTrait.columns: + fn = assoc_dir.strip('/') + '/mTrait_' + mTrait_name + '.assoc.txt' + assoc = pd.read_csv(fn, sep='\t') + assoc.index = mTrait_name+';' + assoc['rs'] + mTrait_effect = pd.concat([mTrait_effect, assoc[['beta', 'se']]]) + for pTrait_name in pTrait.columns: + fn = assoc_dir.strip('/') + '/pTrait_' + pTrait_name + '.assoc.txt' + args.append((fn,)) + pTrait_res = mp.parallel(get_MLM_effect, args, threads) + pTrait_effect = pd.concat([i['beta'] for i in pTrait_res], axis=1) + pTrait_effect.columns = pTrait.columns + pTrait_se = pd.concat([i['se'] for i in pTrait_res], axis=1) + pTrait_se.columns = pTrait.columns + return mTrait_effect, pTrait_effect, pTrait_se + + +def MR_MLM(mTrait_effect_snp, pTrait_effect_snp, pTrait_se_snp, pvalue_cutoff): + mTrait_name, rs = mTrait_effect_snp.name.split(';') + bxy = pTrait_effect_snp / mTrait_effect_snp['beta'] + varbXY = var_bXY(mTrait_effect_snp['beta'], mTrait_effect_snp['se'], pTrait_effect_snp, pTrait_se_snp) + TMR = bxy**2 / varbXY + pvalue = 1 - chi2.cdf(TMR, 1) + MR_res = pd.DataFrame(dict(zip(['snp', 'mTrait', 'pTrait', 'effect', 'TMR', 'pvalue'], + [[rs] * pTrait_effect_snp.shape[0], [mTrait_name] * pTrait_effect_snp.shape[0], + pTrait_effect_snp.index, bxy, TMR, pvalue]))) + MR_res = MR_res.loc[(MR_res.pvalue<=pvalue_cutoff) & (MR_res.mTrait!=MR_res.pTrait), :] + return MR_res + + +def MR_MLM_parallel(mTrait_qtl, mTrait_effect, pTrait_effect, pTrait_se, threads, pvalue_cutoff): + args = [] + for index, row in mTrait_qtl.iterrows(): + mTrait_name = row['phe_name'] + rs = row['SNP'] + args.append((mTrait_effect.loc[';'.join([mTrait_name, rs]),:], pTrait_effect.loc[rs, :], pTrait_se.loc[rs, :], pvalue_cutoff)) + res = mp.parallel(MR_MLM, args, threads) + res = pd.concat([i for i in res]) + return res + + +def edge_weight(qtl, MR_res): + MR_res = MR_res.loc[MR_res.mTrait!=MR_res.pTrait,:] + MR_res = MR_res.sort_values(by='pvalue') + qtl_peak_pos = qtl[['CHR', 'phe_name']] + qtl_peak_pos.loc[:, 'start'] = qtl['SNP'].apply(lambda x: int(x.split('_')[-1]) - 1) + qtl_peak_pos.loc[:, 'end'] = qtl['SNP'].apply(lambda x: int(x.split('_')[-1]) + 1) + qtl_peak_pos = qtl_peak_pos[['CHR', 'start', 'end', 'phe_name']] + qtl_peak_pos.columns = ['Chromosome', 'Start', 'End', 'phe_name'] + qtl_peak_range = qtl[['CHR', 'phe_name']] + qtl_peak_range.loc[:, 'start'] = qtl['SNP'].apply(lambda x: int(x.split('_')[-1]) - 1000000) + qtl_peak_range.loc[:, 'end'] = qtl['SNP'].apply(lambda x: int(x.split('_')[-1]) + 1000000) + qtl_peak_range = qtl_peak_range[['CHR', 'start', 'end', 'phe_name']] + qtl_peak_range.columns = ['Chromosome', 'Start', 'End', 'phe_name'] + qtl_peak_pos = pr.PyRanges(qtl_peak_pos) + qtl_peak_range = pr.PyRanges(qtl_peak_range) + coloc = qtl_peak_pos.join(qtl_peak_range) + + res = pd.DataFrame() + for k in sorted(coloc.dfs.keys()): + res = pd.concat([res, coloc.dfs[k]]) + res = res.loc[res.phe_name != res.phe_name_b, :] + mr_coloc = pd.merge(MR_res, res, left_on=['mTrait', 'pTrait'], right_on=['phe_name', 'phe_name_b']) + #coloc = coloc.sort_values(by='pvalue') + mr_coloc_count = mr_coloc.pvalue.value_counts().sort_index() + mr_count = MR_res.pvalue.value_counts().sort_index().cumsum() + weight = list() + for p in MR_res.pvalue.values: + coloc_count = mr_coloc_count[mr_coloc_count.index <= p].sum() + weight.append(coloc_count / mr_count[p]) + MR_res.loc[:, 'weight'] = weight + edgeweight = MR_res[['mTrait', 'pTrait', 'weight']] + edgeweight = edgeweight.loc[edgeweight.weight >= 0.2, :] + edgeweight = edgeweight.sort_values(by=['mTrait', 'pTrait', 'weight'], ascending=False).drop_duplicates(subset=['mTrait', 'pTrait']) + return edgeweight + +