a b/downstream_analysis/utilities.R
1
library(reshape2)
2
library(ggplot2)
3
library(ComplexHeatmap)
4
library(RColorBrewer)
5
library(viridis)
6
library(grid)
7
library(circlize)
8
library(ggrepel)
9
10
#plot settings
11
source('~/plot/publication_plot_theme.R')
12
#### some functions
13
14
#' Rank features within each topic.
#'
#' For each of the first `num_topic` columns of `topic_matrix`, sorts the
#' feature (row) names by decreasing value.
#'
#' @param topic_matrix Matrix or data frame with one row per feature (row
#'   names are the feature names) and one column per topic.
#' @param num_topic Number of leading topic columns to rank (default 100).
#' @return Character matrix with one row per feature and `num_topic` columns
#'   named `topic1`, `topic2`, ...; column i lists the feature names of
#'   topic i from highest to lowest value.
get_feature_orders <- function(topic_matrix, num_topic = 100) {
  feature_names <- rownames(topic_matrix)
  if (is.null(feature_names)) {
    stop("topic_matrix must have row names (feature names)", call. = FALSE)
  }
  if (num_topic > ncol(topic_matrix)) {
    stop("num_topic (", num_topic, ") exceeds the number of topic columns (",
         ncol(topic_matrix), ")", call. = FALSE)
  }

  # Preallocate the result instead of growing it with cbind() in the loop.
  feature_ids <- matrix(
    NA_character_,
    nrow = nrow(topic_matrix),
    ncol = num_topic,
    dimnames = list(NULL, paste0("topic", seq_len(num_topic)))
  )
  for (i in seq_len(num_topic)) {
    value <- topic_matrix[, i]
    feature_ids[, i] <- feature_names[order(value, decreasing = TRUE)]
  }
  feature_ids
}
26
27
## calculate correlation between ranks

#' Map feature names to identifiers using a lookup table (helper for cal_corr).
#'
#' @param feature_names Character vector of names to look up.
#' @param id_table Data frame holding a name column and an id column.
#' @param name_index,id_index Column index or name of the name / id columns.
#' @return Data frame with character columns `name` and `id`, one row per
#'   (name, id) match. Rows follow the order of `feature_names`; a name with
#'   several ids keeps them in table order. Rows with a missing id are dropped.
.map_names_to_ids <- function(feature_names, id_table, name_index, id_index) {
  table_names <- as.character(id_table[, name_index])
  table_ids <- as.character(id_table[, id_index])
  hits <- which(table_names %in% feature_names & !is.na(table_ids))
  # order() is stable, so ties (same name, several ids) keep table order.
  hits <- hits[order(match(table_names[hits], feature_names))]
  data.frame(name = table_names[hits], id = table_ids[hits],
             stringsAsFactors = FALSE)
}

#' Per-topic rank agreement between genes and their matching proteins.
#'
#' Gene and protein names are mapped to shared identifiers through CSV lookup
#' tables. The tables are read from paths relative to the working directory:
#' `./data/covid/{rna,protein}_name.csv` when the first `/`-separated part of
#' `version` is `"covid"`, otherwise `./data/rna_name.csv` and
#' `./protein_name_complete.csv`. Each shared id contributes one gene/protein
#' pair (the first matching name on each side), so the two feature sets
#' always align row by row.
#'
#' Within each topic the paired features are ranked separately for RNA and
#' protein. The function then computes the Pearson correlation of the two rank
#' vectors and a paired two-sided Wilcoxon test on them.
#'
#' @param version Dataset version string, e.g. `"covid/..."`.
#' @param name_index,id_index Column (index or name) of the feature name and
#'   the identifier in both lookup tables.
#' @param rna_topic Feature-by-topic values for genes (row names = gene names).
#' @param adt_topic Feature-by-topic values for proteins (row names = protein
#'   names); assumed to have the same topic columns as `rna_topic`.
#' @param num_topic Number of leading topics to evaluate; `NULL` (default)
#'   means every topic column. Previously this was read from a global variable.
#' @return Data frame with columns `cor`, `p_value` and `p_value_adj`
#'   (Benjamini-Hochberg adjusted), one row per topic.
cal_corr <- function(version, name_index, id_index, rna_topic, adt_topic,
                     num_topic = NULL) {
  # read all gene ids and protein ids
  if (strsplit(version, "/")[[1]][1] == "covid") {
    all_genes <- read.csv("./data/covid/rna_name.csv", header = TRUE)
    all_proteins <- read.csv("./data/covid/protein_name.csv", header = TRUE)
  } else { # nips
    all_genes <- read.csv("./data/rna_name.csv", header = TRUE)
    all_proteins <- read.csv("./protein_name_complete.csv", header = TRUE)
  }

  # get corresponding ids based on gene name & protein name
  gene_name_id <- .map_names_to_ids(rownames(rna_topic), all_genes,
                                    name_index, id_index)
  protein_name_id <- .map_names_to_ids(rownames(adt_topic), all_proteins,
                                       name_index, id_index)

  common_ids <- intersect(gene_name_id$id, protein_name_id$id)
  if (length(common_ids) < 2) {
    stop("fewer than two shared gene/protein ids; cannot correlate ranks",
         call. = FALSE)
  }
  # Exactly one gene and one protein per shared id, so rows stay paired even
  # when several names map to the same id.
  common_genes <- gene_name_id$name[match(common_ids, gene_name_id$id)]
  common_pro <- protein_name_id$name[match(common_ids, protein_name_id$id)]

  # make sure the same order (row k of both subsets refers to common_ids[k])
  rna_topic_sub <- rna_topic[common_genes, , drop = FALSE]
  protein_topic_sub <- adt_topic[common_pro, , drop = FALSE]

  # ranks within each topic (column)
  rank_gene <- apply(rna_topic_sub, 2, rank)
  rank_pro <- apply(protein_topic_sub, 2, rank)

  if (is.null(num_topic)) {
    num_topic <- ncol(rank_gene)
  }
  stopifnot(num_topic <= ncol(rank_gene), num_topic <= ncol(rank_pro))
  topics <- seq_len(num_topic)

  corr <- vapply(topics, function(i) cor(rank_gene[, i], rank_pro[, i]),
                 numeric(1))
  # wilcox paired test
  p <- vapply(topics, function(i) {
    wilcox.test(rank_gene[, i], rank_pro[, i],
                paired = TRUE, alternative = "two.sided")$p.value
  }, numeric(1))
  p_adjust <- p.adjust(p, "BH")

  data.frame(cor = corr, p_value = p, p_value_adj = p_adjust)
}
99
100
# plot heatmap for top genes and protein

#' Tile heatmap of the top features of selected topics.
#'
#' Takes the top `top_feature_num` feature names of every selected topic. It
#' looks up their values in each selected topic, scales each topic column by
#' its largest absolute value, and saves a blue-white-red tile plot as PNG.
#'
#' @param type Axis label and file-name fragment (e.g. gene / protein).
#' @param top_feature_num Number of top-ranked features taken per topic.
#' @param all_ids Ranked feature-name matrix as returned by
#'   `get_feature_orders()` (one column per topic).
#' @param topic_matrix Feature-by-topic values; row names are feature names.
#' @param selected_topic Topic columns to show; used to index both `all_ids`
#'   and `topic_matrix`. A single topic is supported.
#' @param save_path Prefix prepended to the file name (pasted, not joined).
#' @param name Extra file-name suffix.
#' @param width,height Figure size passed to `ggsave()`.
#' @return Whatever `ggsave()` returns.
plot_top_feature_in_selected_topic <- function(type, top_feature_num,
                                               all_ids, topic_matrix,
                                               selected_topic,
                                               save_path, name = "",
                                               width = 5, height = 10) {
  # All top feature names, stacked topic by topic. drop = FALSE keeps this a
  # matrix when only one topic is selected; as.vector() reads it column-wise
  # and avoids the melt()-of-a-vector failure.
  top <- all_ids[seq_len(top_feature_num), selected_topic, drop = FALSE]
  all_names <- as.vector(top)

  # Values of those features in each selected topic.
  m <- as.matrix(topic_matrix[all_names, selected_topic, drop = FALSE])
  dimnames(m) <- list(all_names, as.character(selected_topic))

  # Scale within each topic by the column's largest absolute value; sweep()
  # keeps the matrix shape even for a single row.
  col_max <- apply(abs(m), 2, max)
  m_scale <- sweep(m, 2, col_max, "/")

  m_plot <- melt(m_scale)
  # Keep topics in the order they were selected.
  m_plot$Var2 <- factor(m_plot$Var2, levels = unique(m_plot$Var2))
  plot <- ggplot(m_plot, aes(x = Var2, y = Var1, fill = value)) +
    geom_tile() +
    scale_fill_gradient2(low = "blue", mid = "white", high = "red",
                         midpoint = 0) +
    labs(title = "", x = "topic", y = type) +
    theme_bw() +
    theme(axis.text.x = element_text(angle = -90, hjust = 0, size = 10),
          axis.text.y = element_text(size = 10))

  save_name <- paste0(save_path, "top", top_feature_num, type, name, ".png")
  ggsave(filename = save_name, plot = plot, dpi = 320,
         width = width, height = height)
}
131
132
### plot sample by topic

#' Heatmap of per-cell topic values with cell-metadata annotations.
#'
#' @param plot_cell_topic Data frame whose first `num_select_topic` columns
#'   hold topic values (one row per cell). It must also carry the metadata
#'   columns used for the annotation (see `version`).
#' @param version Dataset flavour. `"covid"` annotates with
#'   `initial_clustering`, `Status_on_day_collection_summary` and `Status`.
#'   `"rna_atac"` annotates with `cellType`. Anything else uses `cellType1`
#'   and `cellType2`.
#' @param coll Annotation colour mapping passed to `HeatmapAnnotation(col =)`.
#' @param num_select_topic Number of leading topic columns to plot.
#' @param save_name Output PNG path (1500 x 1000 px, res 100).
#' @param is_cluster Whether to cluster the cell (column) dimension.
#' @return The value of the final `dev.off()` call.
plot_sample_by_topic <- function(plot_cell_topic, version, coll,
                                 num_select_topic, save_name,
                                 is_cluster = TRUE) {
  # Create the heatmap annotation
  if (version == "covid") {
    ha <- HeatmapAnnotation(
      cellType = plot_cell_topic$initial_clustering,
      severity = plot_cell_topic$Status_on_day_collection_summary,
      status = plot_cell_topic$Status,
      col = coll
    )
  } else if (version == "rna_atac") { # atac
    ha <- HeatmapAnnotation(
      cellType = plot_cell_topic$cellType,
      col = coll,
      annotation_name_gp = gpar(fontsize = 20)
    )
  } else {
    ha <- HeatmapAnnotation(
      cellType1 = plot_cell_topic$cellType1,
      cellType2 = plot_cell_topic$cellType2,
      col = coll,
      annotation_name_gp = gpar(fontsize = 20)
    )
  }

  # Topics become rows and cells become columns. drop = FALSE keeps the topic
  # name as the row label even when a single topic is selected.
  # NOTE(review): the selected columns must all be numeric.
  topic_cols <- plot_cell_topic[, seq_len(num_select_topic), drop = FALSE]
  plot_matrix <- as.matrix(t(topic_cols))
  # Diverging scale: blue at the minimum, white at 0, red at the maximum.
  col_fun <- colorRamp2(c(min(plot_matrix), 0, max(plot_matrix)),
                        c("blue", "white", "red"))

  png(file = save_name,
      width = 1500, height = 1000, units = "px", bg = "transparent", res = 100)
  # Close the PNG device even if Heatmap()/draw() fails. Otherwise a leaked
  # device would capture every later plot in the session.
  device_open <- TRUE
  on.exit(if (device_open) dev.off(), add = TRUE)

  h <- Heatmap(plot_matrix, col = col_fun,
               show_column_names = FALSE,
               cluster_columns = is_cluster,
               cluster_rows = TRUE,
               top_annotation = ha,
               row_names_gp = grid::gpar(fontsize = 20))
  draw(h)
  device_open <- FALSE
  dev.off()
}
173
174
## plot correlation

#' Scatter of per-topic gene/protein rank correlations, saved as a 10 x 3 figure.
#'
#' @param corr_plot Data frame with columns `x`, `y` and `color`.
#' @param save_name Output file path for `ggsave()`.
#' @return Whatever `ggsave()` returns.
plot_correlation <- function(corr_plot, save_name) {
  large_text <- element_text(size = 20)
  publication_look <- theme(
    panel.grid = element_blank(),
    panel.grid.major = element_line(colour = NA),
    legend.position = "none",
    axis.title.x = large_text,
    axis.title.y = large_text,
    axis.text.y = large_text,
    axis.text.x = large_text
  )
  titles <- labs(x = "", y = "correlation",
                 title = "Correlation among genes and proteins")
  # Dashed reference line at zero correlation.
  zero_line <- geom_hline(yintercept = 0, linetype = "dashed")

  fig <- ggplot(corr_plot, aes(x = x, y = y, color = color)) +
    geom_point(size = 2) +
    theme_Publication(base_family = "Arial") +
    publication_look +
    titles +
    zero_line
  ggsave(filename = save_name, plot = fig, dpi = 320, width = 10, height = 3)
}
190
191
## plot q values

#' Scatter of -ln(q) values with a red dashed threshold line at y = 3.
#'
#' @param plot_data Data frame with columns `x`, `y`, `color` and `label`.
#'   `label` is mapped but currently not drawn (no text layer).
#' @param save_name Output file path for `ggsave()`.
#' @return Whatever `ggsave()` returns.
plot_q_value <- function(plot_data, save_name) {
  large_text <- element_text(size = 20)
  mapping <- aes(x = x, y = y, color = color, label = label)
  look <- theme(
    panel.grid = element_blank(),
    panel.grid.major = element_line(colour = NA),
    legend.position = "none",
    axis.title.x = large_text,
    axis.title.y = large_text,
    axis.text.y = large_text,
    axis.text.x = large_text
  )
  threshold <- geom_hline(yintercept = 3, linetype = "dashed", col = "red")

  # A geom_text_repel() label layer can be added here if annotations are needed.
  fig <- ggplot(plot_data, mapping) +
    geom_point(alpha = 0.3, size = 1.5) +
    theme_Publication(base_family = "Arial") +
    look +
    labs(x = "", y = "-ln(q value)", title = "") +
    threshold
  ggsave(filename = save_name, plot = fig, dpi = 320, width = 10, height = 3)
}
208
209
## plot p values for covid topics

#' Labelled scatter of per-topic -log(p) values (natural log).
#'
#' Adds exp(-20) to each p before the log, so p = 0 is capped at 20 instead of
#' producing Inf.
#'
#' @param covid_p_plot Data frame with columns `topic`, `p`, `color`, `label`.
#' @param save_name Output file path for `ggsave()`.
#' @return Whatever `ggsave()` returns.
plot_topic_p_value <- function(covid_p_plot, save_name) {
  large_text <- element_text(size = 20)
  look <- theme(
    panel.grid = element_blank(),
    panel.grid.major = element_line(colour = NA),
    legend.position = "none",
    axis.title.x = large_text,
    axis.title.y = large_text,
    axis.text.y = large_text,
    axis.text.x = large_text
  )

  fig <- ggplot(covid_p_plot,
                aes(x = topic, y = -log(p + exp(-20)),
                    color = color, label = label)) +
    geom_point(size = 2) +
    geom_text_repel() +
    theme_Publication(base_family = "Arial") +
    look +
    labs(x = "", y = "-log(p)", title = "")
  ggsave(filename = save_name, plot = fig, dpi = 320, width = 10, height = 3)
}
226
## Differential analysis of topic expression

#' One-vs-rest t-tests of topic values for every level of a cell label.
#'
#' For each distinct value of `cell_info[[label]]` and each of the first
#' `num_topic` topics, cells with that value are compared against all other
#' cells using `t.test(alternative = alter)`. All p-values are then adjusted
#' together with `p.adjust()` (default method, "holm").
#'
#' @param label Name of the column in `cell_info` holding the cell labels.
#' @param cell_info Data frame of per-cell metadata, rows aligned with
#'   `cell_topic`.
#' @param cell_topic Cell-by-topic matrix or data frame.
#' @param num_topic Number of leading topic columns to test (default 100).
#' @param alter `alternative` passed to `t.test()` ("greater" tests for
#'   up-regulation in the labelled group).
#' @return List with two topic-by-label data frames (columns named by label
#'   value): `topic_label_p`, the adjusted p-values, and `topic_label_mean`,
#'   the mean(in group) - mean(rest) differences.
diff_topic <- function(label, cell_info, cell_topic, num_topic = 100,
                       alter = "greater") {
  if (!label %in% names(cell_info)) {
    stop("column '", label, "' not found in cell_info", call. = FALSE)
  }
  labels <- cell_info[[label]]
  label_type <- unique(labels)
  topics <- seq_len(num_topic)

  # topic x label matrices (preallocated) for raw p values / mean differences
  p_raw <- matrix(NA_real_, nrow = num_topic, ncol = length(label_type))
  mean_diff <- matrix(NA_real_, nrow = num_topic, ncol = length(label_type))

  for (k in seq_along(label_type)) {
    # %in% never returns NA, so cells with a missing label fall in the "rest"
    # group instead of producing NA rows.
    in_group <- labels %in% label_type[k]
    group1 <- cell_topic[in_group, , drop = FALSE] # topic values with label
    group2 <- cell_topic[!in_group, , drop = FALSE] # topic values without label
    for (j in topics) {
      p_raw[j, k] <- t.test(group1[, j], group2[, j],
                            alternative = alter)$p.value
      mean_diff[j, k] <- mean(group1[, j]) - mean(group2[, j])
    }
  }

  # Adjust across the whole topic x label grid; p.adjust() returns a plain
  # vector, so restore the matrix shape (column-major, like the input).
  topic_label_p <- as.data.frame(matrix(p.adjust(p_raw), nrow = num_topic))
  colnames(topic_label_p) <- label_type

  topic_label_mean <- as.data.frame(mean_diff)
  colnames(topic_label_mean) <- label_type

  list(topic_label_p = topic_label_p, topic_label_mean = topic_label_mean)
}
256
257
## plot differentially expressed topics

#' Heatmap of topic-by-label mean differences with significance stars.
#'
#' Cells whose adjusted p-value is below 0.001 get a "*". Topics (rows) are
#' left in their original order; label columns are clustered.
#'
#' @param topic_label_list Output of `diff_topic()`: a list with
#'   `topic_label_mean` and `topic_label_p` (same dimensions).
#' @param save_name Output PNG path (height 1000 px, res 120).
#' @param width PNG width in pixels.
#' @param topic_select Row indices of the topics to show. `NULL` (default)
#'   shows every topic. The previous default `1:100` only worked with exactly
#'   100 topics, so this change is backward-compatible.
#' @return The value of the final `dev.off()` call.
plot_diff_topic <- function(topic_label_list, save_name, width = 1000,
                            topic_select = NULL) {
  topic_label_mean <- as.matrix(topic_label_list$topic_label_mean)
  n_topic <- nrow(topic_label_mean)
  # Row labels are topic numbers; derived from the data, not hard-coded to 100.
  rownames(topic_label_mean) <- seq_len(n_topic)
  topic_label_p <- as.matrix(topic_label_list$topic_label_p)
  if (is.null(topic_select)) {
    topic_select <- seq_len(n_topic)
  }

  # drop = FALSE keeps matrices when one topic or one label is selected.
  topic_label_mean_select <- topic_label_mean[topic_select, , drop = FALSE]
  topic_label_p_select <- topic_label_p[topic_select, , drop = FALSE]

  # Diverging scale centred at 0.
  col_fun <- colorRamp2(
    c(min(topic_label_mean_select), 0, max(topic_label_mean_select)),
    c("blue", "white", "red")
  )
  # Alternative fixed range used for the confounder plots:
  # col_fun = colorRamp2(c(-0.6124488, 0, 0.7623149), c("blue", "white", "red"))

  png(file = save_name, width = width, height = 1000, units = "px",
      bg = "transparent", res = 120)
  # Close the PNG device even if Heatmap()/draw() fails.
  device_open <- TRUE
  on.exit(if (device_open) dev.off(), add = TRUE)

  h <- Heatmap(
    topic_label_mean_select, col = col_fun,
    cluster_rows = FALSE, cluster_columns = TRUE,
    row_names_gp = gpar(fontsize = 20), column_names_gp = gpar(fontsize = 20),
    # i, j index the un-reordered matrix, so they also index the p matrix.
    cell_fun = function(j, i, x, y, w, h, fill) {
      if (topic_label_p_select[i, j] < 0.001) {
        gb <- textGrob("*")
        gb_w <- convertWidth(grobWidth(gb), "mm")
        gb_h <- convertHeight(grobHeight(gb), "mm")
        # Nudge the star down so it sits visually centred in the cell.
        grid.text("*", x, y - gb_h * 0.5 + gb_w * 0.4,
                  gp = gpar(fontsize = 20))
      }
    }
  )
  draw(h)
  device_open <- FALSE
  dev.off()
}