a b/run_joint_model_analysis.R
1
## Run all of the analyses
2
library(JMbayes)
3
library(dplyr)
4
library("parallel")
5
library(ggplot2)
6
require(survival)
7
require(survminer)
8
library(caret)
9
require(R.utils)
10
library(reshape2)
11
12
# Resolve relative paths against the folder that holds this script.
# Inside RStudio the working directory is switched to the folder of the file
# open in the source editor. rstudioapi can only answer that query inside a
# running RStudio session, so in any other session (e.g. Rscript) the call is
# skipped and the script must be launched from its own folder.
if (requireNamespace("rstudioapi", quietly = TRUE) && rstudioapi::isAvailable()) {
  setwd(dirname(rstudioapi::getSourceEditorContext()$path))
}

# Load the project helper scripts used by the sections below.
source("function/cross_validation.R")
source("function/plot_evaluation.R")
source("function/leaveoneout.R")
source("function/reliability_diagram.R")
18
19
#### ---- Run all analysis scripts ---- ####
20
21
# Read the longitudinal data set and store the categorical covariates as factors.
mydata <- read.csv("data/Joint_model_data.csv") %>%
  mutate(
    T_stage = as.factor(T_stage),
    TP53 = as.factor(TP53)
  )
# Shifted log scale: the 1e-6 offset avoids log(0), and subtracting log(1e-6)
# maps meanAF == 0 to exactly 0.
mydata$logAF <- log(mydata$meanAF + 1e-6) - log(1e-6)

# One row per patient (the first row found for each PatientID), plus the
# subset of those rows whose PathologicalType is "AD".
mydata.id <- mydata[!duplicated(mydata$PatientID), ]
mydata.id.AD <- subset(mydata.id, PathologicalType == "AD")


Tstart <- 244 / 30 # landmarking time, using data up to plasma time point 4
predict_times <- c(12, 15) # horizon time

##### Repeated 5-fold cross-validation splits #####
# Code that draws fresh splits (kept for reference, not run):
# set.seed(888)
# splits=createMultiFolds(factor(mydata.id$DFS_status),k=5,times=5)
# splits_AD=createMultiFolds(factor(mydata.id.AD$DFS_status),k=5,times=10)

## For reproduction, load the cross-validation splits used in the paper
splits <- readRDS("data/splits.RDS")
splits_AD <- readRDS("data/splits.AD.RDS")
39
40
### run cross validation for models using all patients
41
42
## using multiple cores
43
# Start a local cluster with at most 10 workers, and never more than the
# number of cores the machine reports. detectCores() may return NA; na.rm
# makes min() ignore it and fall back to 10.
cores <- parallel::detectCores()
cl <- makeCluster(min(10L, cores, na.rm = TRUE))

# Everything that talks to the cluster runs inside tryCatch() so that
# `finally` shuts the workers down even when a fold fails; otherwise an error
# would leave the worker processes running. invisible() keeps the value of
# the block from being printed at top level.
invisible(tryCatch(
  {
    # Objects the workers need to evaluate runCV().
    clusterExport(
      cl,
      c("mydata", "mydata.id", "runCV", "Tstart", "predict_times", "mydata.id.AD")
    )
    # Packages attached on every worker.
    clusterEvalQ(cl, {
      library(splines)
      library(nlme)
      library("JMbayes")
      library(dplyr)
      library(survival)
      library("xtable")
      library("lattice")
    })
    # One runCV() call per split: all patients first, then the AD subset.
    res_all <- parLapply(
      cl, splits,
      function(x) runCV(x, Tstart, predict_times, mydata, mydata.id)
    )
    res_AD <- parLapply(
      cl, splits_AD,
      function(x) runCV(x, Tstart, predict_times, mydata, mydata.id.AD)
    )
  },
  finally = stopCluster(cl)
))
52
53
54
# result 
55
# Stack the per-split results into one data frame per cohort; the
# "column_label" column identifies the list element each row came from.
out_all <- bind_rows(res_all, .id = "column_label")
out_AD <- bind_rows(res_AD, .id = "column_label")

## Plot the CV results

# All patients:
#   CV_all_patients/JMvsCox_testing.pdf  ~ Fig 4b
#   CV_all_patients/JMvsCox_training.pdf ~ Supplementary Fig 10
#   CV_all_patients/betweenJMs.pdf       ~ Supplementary Fig 15
plot_evaluation("results/CV_all_patients", out_all)

# AD patients:
#   CV_AD_patients/JMvsCox_testing.pdf, JMvsCox_training.pdf ~ Supplementary Fig 12
plot_evaluation("results/CV_AD_patients", out_AD)
66
67
## plot the personalized prediction ~ Fig 4c,d; Supplementary Fig 13
68
# Longitudinal submodel: mixed-effects model of logAF on a 2-df natural spline
# of TestDate, with the same spline as random effects per PatientID.
lm <- lme(
  logAF ~ ns(TestDate, 2),
  data = mydata,
  random = ~ ns(TestDate, 2) | PatientID,
  control = lmeControl(opt = "optim", msMaxIter = 1000)
)

# Survival submodel: Cox regression on TP53 and T_stage, one row per patient.
# x = TRUE keeps the design matrix in the fitted object.
fit <- coxph(
  Surv(DFS, DFS_status) ~ TP53 + T_stage,
  data = mydata.id,
  x = TRUE
)

# Extra term handed to the joint model through `extraForm`, built from
# TestDate and ins(TestDate, 2) for both fixed and random effects.
iForm <- list(
  fixed = ~ 0 + TestDate + ins(TestDate, 2),
  random = ~ 0 + TestDate + ins(TestDate, 2),
  indFixed = 1:3,
  indRandom = 1:3
)

# Joint model fitted on all patients; used for the personalized predictions.
final_jm <- jointModelBayes(
  lm, fit,
  timeVar = "TestDate",
  param = "td-extra",
  extraForm = iForm
)
75
76
# Output folder for the per-patient figures.
if (!file.exists("personalized")) dir.create("personalized")

# One PNG per patient: a row of panels, one per visit from the second visit
# onward, each showing the prediction made with the data available up to
# that visit.
for (PID in unique(mydata.id$PatientID)) {
  ND <- mydata[mydata$PatientID == PID, ]
  l <- nrow(ND)
  # Patients with a single measurement get no figure.
  if (l < 2) next

  Relapse_status <- ifelse(ND$DFS_status[1] == 1, "Relapsed", "Relapse-free")

  # Figure width grows with the number of panels (l - 1).
  png(
    paste0("personalized/", PID, ".png"),
    width = (3 * (l - 1) + 1) * 100, height = 4 * 100,
    pointsize = 12, bg = "transparent"
  )

  survPreds <- vector("list", l)
  for (i in seq_len(l)) {
    # Use a loop-local name here. `Tstart` is the global landmark time that
    # the leave-one-out section further down still reads, so it must not be
    # overwritten by the visit times of individual patients.
    t_last <- ND[i, "TestDate"]
    # Prediction grid: from the current visit in steps of 10/30, up to
    # max(visit + 6, 18), capped at 19.
    pred_grid <- seq(t_last, min(max(t_last + 180 / 30, 540 / 30), 570 / 30), 10 / 30)
    survPreds[[i]] <- survfitJM(
      final_jm,
      idVar = "PatientID",
      newdata = ND[1:i, ],
      survTimes = pred_grid
    )
  }

  par(mfrow = c(1, l - 1), oma = c(0, 2, 2, 2))
  for (i in 2:l) {
    plot(
      survPreds[[i]],
      estimator = "median", include.y = TRUE,
      main = paste0("Follow-up time(months): ", round(survPreds[[i]]$last.time, 1)),
      xlab = "Time (months)", conf.int = TRUE, ylab = "", ylab2 = "",
      cex.lab = 1.5, cex.main = 1.5
    )
  }
  # Shared axis labels and the figure title go in the outer margins.
  mtext("log mean VAF", side = 2, line = -1, outer = TRUE, cex = 1.5)
  mtext("Recurrence-free Probability", side = 4, line = -1, outer = TRUE, cex = 1.5)
  mtext(paste0(PID, ", ", Relapse_status), outer = TRUE, cex = 1.5, side = 3, adj = 0)
  dev.off()
}
106
                    
107
##### leave-one-out cross-validation #####
108
## joint model
109
# Start a local cluster with at most 10 workers, and never more than the
# number of cores the machine reports. detectCores() may return NA; na.rm
# makes min() ignore it and fall back to 10.
cores <- parallel::detectCores()
cl <- makeCluster(min(10L, cores, na.rm = TRUE))

# Joint model: one runLOO_jm() call per row index of mydata.id, run on the
# cluster. `finally` stops the workers even when a run fails, so no worker
# processes are left behind.
res_LOO_jm <- tryCatch(
  {
    # Objects the workers need to evaluate runLOO_jm().
    clusterExport(cl, c("mydata", "mydata.id", "runLOO_jm", "Tstart", "predict_times"))
    # Packages attached on every worker.
    clusterEvalQ(cl, {
      library(splines)
      library(nlme)
      library("JMbayes")
      library(dplyr)
      library(survival)
      library("xtable")
      library("lattice")
    })
    parLapply(
      cl, seq_len(nrow(mydata.id)),
      function(x) runLOO_jm(x, Tstart, predict_times, mydata, mydata.id)
    )
  },
  finally = stopCluster(cl)
)

## landmarking cox model (run sequentially in the main session)
res_LOO_cox <- lapply(
  seq_len(nrow(mydata.id)),
  function(x) runLOO_cox(x, Tstart, predict_times, mydata, mydata.id)
)
119
120
## result data frame
121
# Row-bind the per-patient results of each model, then join the two tables on
# the columns of mydata.id; all = TRUE keeps rows present in only one of them.
dt_risk_LOO_jm <- do.call("rbind", res_LOO_jm)
dt_risk_LOO_cox <- do.call("rbind", res_LOO_cox)
dt_risk_LOO <- merge(
  dt_risk_LOO_jm, dt_risk_LOO_cox,
  by = colnames(mydata.id),
  all = TRUE
)

## Plot the reliability diagrams ~ Supplementary Fig 11
# Each call receives two data frames (joint model, then Cox model), each
# limited to the rows where that model's probability is available.

# at 12 months
reliability_diagram(
  list(
    subset(dt_risk_LOO, !is.na(prob1)) %>%
      select(DFS_status, DFS, prob1),
    subset(dt_risk_LOO, !is.na(cox1_prob1)) %>%
      select(DFS_status, DFS, cox1_prob1)
  ),
  u = 12, stat_type = "C", bins = 5,
  c("Joint model", "Cox model"), c("blue", "red"), "12 Months"
)

# at 15 months
reliability_diagram(
  list(
    subset(dt_risk_LOO, !is.na(prob2)) %>%
      select(DFS_status, DFS, prob2),
    subset(dt_risk_LOO, !is.na(cox1_prob2)) %>%
      select(DFS_status, DFS, cox1_prob2)
  ),
  u = 15, stat_type = "C", bins = 5,
  c("Joint model", "Cox model"), c("blue", "red"), "15 Months"
)