Switch to side-by-side view

--- a
+++ b/downstream_analysis/utilities.R
@@ -0,0 +1,283 @@
+library(reshape2)
+library(ggplot2)
+library(ComplexHeatmap)
+library(RColorBrewer)
+library(viridis)
+library(grid)
+library(circlize)
+library(ggrepel)
+
+#plot settings
+source('~/plot/publication_plot_theme.R')
+#### some functions
+
+#order features
+get_feature_orders=function(topic_matrix,num_topic=100){
+
+  feature_ids=c()#ordered ids
+  for (i in 1:num_topic){
+    value=topic_matrix[,i]
+    feature_ordered=rownames(topic_matrix)[order(value,decreasing = T)]
+    feature_ids=cbind(feature_ids,feature_ordered)
+  }
+  colnames(feature_ids)=paste0('topic',1:num_topic)
+  return(feature_ids)
+}
+
+##calculate correlation between ranks
+cal_corr=function(version,name_index,id_index,rna_topic,adt_topic){
+  #read all gene ids and protein ids
+  if (strsplit(version,'/')[[1]][1]=='covid'){
+    all_genes=read.csv('./data/covid/rna_name.csv',header=T)
+    all_proteins=read.csv('./data/covid/protein_name.csv',header=T)
+  } else{#nips
+    all_genes=read.csv('./data/rna_name.csv',header=T)
+    all_proteins=read.csv('./protein_name_complete.csv',header=T)
+  }
+
+  #get corresponding ids based on gene name & protein name
+  gene_name_id=c()
+  for (i in rownames(rna_topic)){
+    id=all_genes[all_genes[,name_index]==i,][,id_index]
+    if (length(id)!=0){
+      for (d in id){#if more than 1 ids
+        gene_name_id=rbind(gene_name_id,c(i,d))
+      }    
+    }
+  }
+  colnames(gene_name_id)=c('name','id')
+  gene_name_id=as.data.frame(gene_name_id)
+
+  protein_name_id=c()
+  for (i in rownames(adt_topic)){
+    id=all_proteins[all_proteins[,name_index]==i,][,id_index]
+    if (length(id)!=0){
+    protein_name_id=rbind(protein_name_id,c(i,id))}
+  }
+  colnames(protein_name_id)=c('name','id')
+  protein_name_id=as.data.frame(protein_name_id)
+
+  ### select common ids 
+  common_ids=intersect(gene_name_id$id,protein_name_id$id)
+  
+  ## get corresponding names
+  common_pro=c()
+  for (i in common_ids){
+    common_pro=c(common_pro,protein_name_id[protein_name_id$id==i,]$name)
+  }
+  
+  common_genes=c()
+  for (i in common_ids){
+    common_genes=c(common_genes,gene_name_id[gene_name_id$id==i,]$name)
+  }
+  
+  ##make sure the same order
+  protein_topic_sub=adt_topic[common_pro,]
+  rna_topic_sub=rna_topic[common_genes,]
+
+  ## under each topic, get rank, calculate correlation for each topic
+  rank_gene=apply(rna_topic_sub,2,rank)#column
+  rank_pro=apply(protein_topic_sub,2,rank)#column
+  corr=c()
+  p=c()#wilcox paired test
+
+  for (i in 1:num_topic){
+    x=rank_gene[,i]
+    y=rank_pro[,i]
+    c=cor(x,y)
+    corr=c(corr,c)
+    
+    w=wilcox.test(x, y, paired = TRUE, alternative = "two.sided")
+    p=c(p,w$p.value)
+  }
+  p_adjust=p.adjust(p,'BH')
+
+  plot_corr=data.frame(cor=corr,p_value=p,p_value_adj=p_adjust)
+
+  return(plot_corr)
+}
+
+#plot heatmap for top genes and protein
+plot_top_feature_in_selected_topic=function(type,top_feature_num,
+                                            all_ids,topic_matrix,
+                                            selected_topic,
+                                            save_path,name='',
+                                            width=5,height=10){
+    top=all_ids[1:top_feature_num,selected_topic]
+    all_names=melt(top)[,3]#all top gene/protein names
+
+    m=c()#values
+    for ( i in selected_topic){
+      m=cbind(m,topic_matrix[all_names,i])
+    }
+
+    rownames(m)=all_names
+    colnames(m)=as.character(selected_topic)
+
+    m_scale=apply(m,2,function(x) x/max(abs(x)))#scale by column(within each topic)
+    m_plot=melt(m_scale)
+    m_plot$Var2=factor(m_plot$Var2,levels=unique(m_plot$Var2))
+    plot=ggplot(m_plot,aes(x=Var2,y=Var1,fill=value))+
+         geom_tile() +
+         scale_fill_gradient2(low = "blue", mid = "white", high = "red",midpoint = 0)+
+         labs(title = "",x='topic',y=type) +
+         theme_bw() +
+         theme(axis.text.x = element_text(angle = -90, hjust = 0,size=10),
+               axis.text.y = element_text(size=10))
+
+    save_name=paste0(save_path,'top',top_feature_num,type,name,'.png')
+    ggsave(filename=save_name,plot=plot,dpi=320,width=width,height=height)
+}
+
+### plot sample by topic
+plot_sample_by_topic=function(plot_cell_topic,version,coll,num_select_topic,save_name,is_cluster=TRUE){
+# Create the heatmap annotation
+  if (version=='covid'){
+    ha <- HeatmapAnnotation(
+      cellType=plot_cell_topic$initial_clustering,
+      severity=plot_cell_topic$Status_on_day_collection_summary,
+      status=plot_cell_topic$Status,
+      col = coll)
+  } else if(version=='rna_atac'){#atac
+    ha <- HeatmapAnnotation(
+      cellType=plot_cell_topic$cellType,
+      col = coll,
+      annotation_name_gp= gpar(fontsize = 20))
+    
+  } else{
+    ha <- HeatmapAnnotation(
+      cellType1=plot_cell_topic$cellType1,
+      cellType2=plot_cell_topic$cellType2,
+      col = coll,
+      annotation_name_gp= gpar(fontsize = 20))
+  }
+
+  #prepare plot data
+  plot_matrix=as.matrix(t(plot_cell_topic[,1:num_select_topic]))
+  col_fun = colorRamp2(c(min(plot_matrix), 0, max(plot_matrix)), c("blue", "white", "red"))
+
+  # Combine the heatmap and the annotation
+  # !!!!! check dimension of cell_topic_select, make sure all are numeric !!!!
+  png(file=save_name,
+      width = 1500, height = 1000,units='px',bg = "transparent",res=100)
+  h=Heatmap(plot_matrix, col=col_fun,
+            show_column_names = FALSE,
+            cluster_columns = is_cluster,
+            cluster_rows = TRUE,
+            top_annotation = ha,
+            row_names_gp = grid::gpar(fontsize = 20)
+            )
+  draw(h)
+  dev.off()
+}
+
+## plot correlation
+plot_correlation=function(corr_plot,save_name){
+  plot_corr=ggplot(corr_plot,aes(x=x,y=y,color=color))+
+    geom_point(size=2)+
+    theme_Publication(base_family='Arial')+
+    theme(panel.grid = element_blank(),
+          panel.grid.major=element_line(colour=NA),
+          legend.position='none',
+          axis.title.x=element_text(size=20),
+          axis.title.y=element_text(size=20),
+          axis.text.y = element_text(size = 20),
+          axis.text.x = element_text(size = 20))+
+    labs(x='',y='correlation',title='Correlation among genes and proteins')+
+    geom_hline(yintercept = 0,linetype = 'dashed')
+  ggsave(filename=save_name,plot=plot_corr,dpi=320,width=10,height=3)
+}
+
+## plot q values
+plot_q_value=function(plot_data,save_name){
+  plot_q=ggplot(plot_data,aes(x=x,y=y,color=color,label=label))+
+    geom_point(alpha=0.3,size=1.5)+
+    #geom_text_repel()+
+    theme_Publication(base_family='Arial')+
+    theme(panel.grid=element_blank(),
+          panel.grid.major=element_line(colour=NA),
+          legend.position='none',
+          axis.title.x=element_text(size=20),
+          axis.title.y=element_text(size=20),
+          axis.text.y = element_text(size = 20),
+          axis.text.x = element_text(size = 20))+
+    labs(x='',y='-ln(q value)',title='')+
+    geom_hline(yintercept = 3,linetype = 'dashed',col='red')
+  ggsave(filename=save_name,plot=plot_q,dpi=320,width=10,height=3)
+}
+
+## plot p values for covid topics
+plot_topic_p_value=function(covid_p_plot,save_name){
+  plot=ggplot(covid_p_plot,aes(x=topic,y=-log(p++exp(-20)),color=color,label=label))+
+    geom_point(size=2)+
+    geom_text_repel()+
+    theme_Publication(base_family='Arial')+
+    theme(panel.grid = element_blank(),
+          panel.grid.major=element_line(colour=NA),
+          legend.position='none',
+          axis.title.x=element_text(size=20),
+          axis.title.y=element_text(size=20),
+          axis.text.y = element_text(size = 20),
+          axis.text.x = element_text(size = 20))+
+    labs(x='',y='-log(p)',title='')#+
+    #geom_hline(yintercept = 0.05,linetype = 'dashed')
+  ggsave(filename=save_name,plot=plot,dpi=320,width=10,height=3)
+}
+## Differential analysis of topic expression
+diff_topic=function(label,cell_info,cell_topic,num_topic=100,alter='greater'){
+  label_index=which(names(cell_info)==label)
+  label_type=unique(cell_info[,label_index])
+  topic_label_p=c()# topic x full_cell_type  matrix storing p values
+  topic_label_mean=c() # topic x full_cell_type  matrix storing mean difference 
+  for (i in label_type){
+    group1=cell_topic[cell_info[,label_index]==i,]#topic values with label
+    group2=cell_topic[cell_info[,label_index]!=i,]#topic values without label
+    topic_p=c()
+    topic_mean=c()
+    for (j in 1:num_topic){
+      topic_p=c(topic_p,t.test(group1[,j],group2[,j],alternative=alter)$p.value)#upregulated
+      topic_mean=c(topic_mean,mean(group1[,j])-mean(group2[,j]))
+    }
+    #topic_p_adj=p.adjust(topic_p)
+    topic_label_p=cbind(topic_label_p,topic_p)
+    topic_label_mean=cbind(topic_label_mean,topic_mean)
+  }
+  #
+  topic_label_adj=p.adjust(topic_label_p)
+  topic_label_p=matrix(topic_label_adj,nrow=num_topic)
+  topic_label_p=as.data.frame(topic_label_p)
+  colnames(topic_label_p)=label_type
+  #
+  topic_label_mean=as.data.frame(topic_label_mean)
+  colnames(topic_label_mean)=label_type
+  
+  return(list('topic_label_p'=topic_label_p,'topic_label_mean'=topic_label_mean))
+}
+
+## plot differntially expressed topics 
+plot_diff_topic=function(topic_label_list,save_name,width=1000,topic_select=c(1:100)){
+  topic_label_mean=as.matrix(topic_label_list$topic_label_mean)
+  rownames(topic_label_mean)=1:100
+  topic_label_p=as.matrix(topic_label_list$topic_label_p)
+  
+  topic_label_mean_select=topic_label_mean[topic_select,]
+  topic_label_p_select=topic_label_p[topic_select,]
+  
+  col_fun = colorRamp2(c(min(topic_label_mean_select), 0, max(topic_label_mean_select)), c("blue", "white", "red"))
+  #col_fun = colorRamp2(c(-0.6124488, 0, 0.7623149), c("blue", "white", "red"))#confounder same range
+  
+  png(file=save_name,width = width, height = 1000,units='px',bg = "transparent",res=120)
+  h=Heatmap(topic_label_mean_select, col=col_fun,cluster_rows = F,cluster_columns = T,
+            row_names_gp = gpar(fontsize = 20),column_names_gp = gpar(fontsize = 20),
+          cell_fun = function(j, i, x, y, w, h, fill) {
+            if(topic_label_p_select[i, j] < 0.001) {
+              gb = textGrob("*")
+              gb_w = convertWidth(grobWidth(gb), "mm")
+              gb_h = convertHeight(grobHeight(gb), "mm")
+              grid.text("*", x, y - gb_h*0.5 + gb_w*0.4,gp = gpar(fontsize = 20))
+            } 
+          }
+          )
+  draw(h)
+  dev.off()
+}