Switch to unified view

a b/function/reliability_diagram.R
1
reliability_diagram = function(data_ls,u, stat_type,bins,
2
                               title_ls,color_ls,title) {
3
  data_ls_len = length(data_ls)
4
  ### create bin averages
5
  for (data_i in 1:data_ls_len){
6
    colnames(data_ls[[data_i]]) = c("status","survival","pred")
7
    temp_res = reliability_datapts(data_ls[[data_i]], u = u,bins = bins, stat_type = stat_type)
8
    temp_hl = hosmer_lemeshow(data_ls[[data_i]], u = u,bins = bins, stat_type = stat_type)
9
    assign(paste("recal_bins", data_i, sep=""),temp_res)
10
    assign(paste("HL_pvalue", data_i, sep=""),temp_hl[1])
11
    assign(paste("HL_statistic", data_i, sep=""),temp_hl[2])
12
  }
13
  
14
  hl_dt = data.frame(x=numeric(0),y=numeric(0),variable= character(0), label = character(0),HLpval = numeric(0),HLstatistic =numeric(0) )
15
  
16
  for (data_i in 1:data_ls_len){
17
    temp_res = get(paste('recal_bins', data_i, sep=''))
18
    temp_res$variable <- paste("Vol.x", data_i, sep='')
19
    colnames(temp_res)[colnames(temp_res)=='V1'] = 'value'
20
    assign(paste('melt', data_i, sep=''), temp_res)
21
    hl_p = get(paste('HL_pvalue', data_i, sep=''))
22
    hl_s = get(paste('HL_statistic', data_i, sep=''))
23
    hl_dt[data_i,] = cbind(0.05,1-0.1*data_i,paste("Vol.x", data_i, sep=''),title_ls[data_i],hl_p[[1]],hl_s[[1]])
24
  }
25
  hl_dt$x = as.numeric(hl_dt$x)
26
  hl_dt$y = as.numeric(hl_dt$y)
27
  hl_dt$HLpval = as.numeric(hl_dt$HLpval)
28
  hl_dt$HLstatistic = as.numeric(hl_dt$HLstatistic)
29
  
30
  data = melt1
31
  if (data_ls_len > 1){
32
    for (data_i in 2:data_ls_len){
33
      data = rbind(data, get(paste('melt', data_i, sep='')))
34
    }
35
  }
36
  
37
  line_plot = ggplot(data, aes(x=pred,  y=obs,color=variable,group= variable))   +  
38
    geom_point(aes(shape = variable),size = 3)+
39
    geom_line(linetype = "solid")+
40
    geom_text(data = hl_dt,aes(x=x,y=y,color = variable,hjust=0,
41
                               label = paste0(label," H-L test p=",round(HLpval,2),", C-statistic=",round(HLstatistic,2))))+
42
    scale_color_manual(labels = title_ls,
43
                       values = color_ls) +
44
    scale_shape_discrete(labels = title_ls)+
45
    scale_linetype_discrete(labels = title_ls)+
46
    guides(color=guide_legend(" "),shape= guide_legend(" "),linetype=guide_legend(" ")) + 
47
    xlab(paste("Predicted Survival Probability at",title)) + ylab(paste("Observed Proportion at",title)) +
48
    geom_abline(intercept = 0, slope = 1, color="black",
49
                linetype="dashed", size=1) +
50
    lims(x=c(0,1),y=c(0,1))+
51
    theme_bw()+
52
    theme(
53
      legend.position="bottom",
54
      axis.title = element_text(size=12),
55
      axis.text = element_text(size=12),
56
      legend.text = element_text(size=15),
57
      legend.title = element_text(size=15),
58
      text = element_text(size=12)
59
    )
60
  line_plot
61
  if (!file.exists("results/LOO_calibration")) dir.create('results/LOO_calibration')
62
  ggsave(paste0("results/LOO_calibration/",title,".pdf"),device = 'pdf',width = 6,height = 6)
63
}
64
65
reliability_datapts <- function(ndata, u,bins=10, stat_type ='C') {
66
  min.pred <- min(ndata$pred)
67
  max.pred <- max(ndata$pred)
68
  min.max.diff <- max.pred - min.pred
69
  
70
  if (stat_type == 'H'){
71
    ndata = ndata[order(ndata$pred),]
72
    res = data.frame(obs= numeric(0), pred = numeric(0),obs_lower=numeric(0),obs_upper=numeric(0))
73
    split_mtx = split(ndata, cut(ndata$pred, seq(0,1,1/bins), include.lowest=TRUE))
74
    for (i in 1:length(split_mtx)){
75
      sfit = summary(survfit(Surv(survival, status) ~ 1,data=split_mtx[[i]]),times=u,extend=T)
76
      obs = sfit[['surv']]
77
      obs_upper =  sfit[['upper']]
78
      obs_lower = sfit[['lower']]
79
      pred = mean(split_mtx[[i]]$pred)
80
      if (sum(is.na(col_mean)) > 0) {
81
        next
82
      }
83
      res[i,] = c(obs,pred,obs_lower,obs_upper)
84
    }
85
  }else{
86
    ## C statistics, same number of instances in each bin
87
    mtx = ndata[order(ndata$pred),]
88
    n <- length(ndata$pred)/bins
89
    nr <- nrow(mtx)
90
    split_mtx = split(mtx, rep(1:ceiling(nr/n), each=n, length.out=nr))
91
    res = data.frame(obs= numeric(0), pred = numeric(0),obs_lower=numeric(0),obs_upper=numeric(0))
92
    for (i in 1:length(split_mtx)){
93
      sfit = summary(survfit(Surv(survival, status) ~ 1,data=split_mtx[[i]]),times=u,extend=T)
94
      obs = sfit[['surv']]
95
      obs_upper =  sfit[['upper']]
96
      obs_lower = sfit[['lower']]
97
      pred = mean(split_mtx[[i]]$pred)
98
      res[i,] = c(obs,pred,obs_lower,obs_upper)
99
    }
100
  }
101
  res
102
}
103
104
hosmer_lemeshow <- function(ndata, u,bins=10, stat_type ='C') {
105
  min.pred <- min(ndata$pred)
106
  max.pred <- max(ndata$pred)
107
  min.max.diff <- max.pred - min.pred
108
  if (stat_type == 'H'){
109
    ndata = ndata[order(ndata$pred),]
110
    res = data.frame(obs= numeric(0), pred = numeric(0),obs_lower=numeric(0),obs_upper=numeric(0))
111
    split_mtx = split(ndata, cut(ndata$pred, seq(0,1,1/bins), include.lowest=TRUE))
112
  }else{
113
    mtx = ndata[order(ndata$pred),]
114
    n <- length(ndata$pred)/bins
115
    nr <- nrow(mtx)
116
    split_mtx = split(mtx, rep(1:ceiling(nr/n), each=n, length.out=nr))
117
  }
118
  H_stat = 0
119
  for (i in 1:length(split_mtx)){
120
    sfit = summary(survfit(Surv(survival, status) ~ 1,data=split_mtx[[i]]),times=u,extend=T)
121
    obs = sfit[['surv']]
122
    exp = mean(split_mtx[[i]]$pred)
123
    obs_not = 1-obs
124
    exp_not = 1-exp
125
    
126
    if (exp == 0 || exp_not == 0){
127
      next
128
    }
129
    bin_sum = ((obs - exp)**2)/exp + ((obs_not - exp_not)**2)/exp_not
130
    
131
    H_stat = H_stat + bin_sum
132
  }
133
  PVAL = 1 - pchisq(H_stat, bins - 2)
134
  
135
  cat('PVALUE', PVAL, '\n')
136
  cat('stat', H_stat, '\n')
137
  return(c(PVAL,H_stat))
138
}
139
140
141
#### test #######
142
# data_ls = list(subset(dt_risk_LOO,!is.na(prob1)) %>% select(DFS_status,DFS,prob1) ,
143
#      subset(dt_risk_LOO,!is.na(cox1_prob1))%>% select(DFS_status,DFS,cox1_prob1))
144
# title_ls= c('Joint model', 'Cox model')
145
# color_ls = c('blue', 'red')
146
# limits=c(0,1)
147
# u=15
148
# bins = 5
149
# title = "15 Months"