b/train.py

import os
import random
import subprocess
import argparse
import time
import numpy as np
import pandas as pd
from sksurv.metrics import concordance_index_censored

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset

from model import HECTOR
from im4MEC import Im4MEC
from utils import *
from utils_loss import NLLSurvLoss

def set_seed():
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
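
# NOTE: seed_worker is meant to be passed to the DataLoaders as worker_init_fn so that NumPy and
# random in each worker process are re-seeded from the torch seed, e.g.
#   DataLoader(dataset, num_workers=args.workers, worker_init_fn=seed_worker)
# It is assumed that define_data_sampling in utils.py wires this up.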

def evaluate_model(epoch, model, model_mol, device, loader, n_bins, writer, loss_fn, bins_values, train_BS, test_BS):
    model.eval()

    eval_loss = 0.

    all_survival_probs = np.zeros((len(loader), n_bins))
    all_risk_scores = np.zeros((len(loader))) # This is the computed risk score.
    all_censorships = np.zeros((len(loader))) # This is the binary censorship status: 1 censored; 0 uncensored (recurred).
    all_event_times = np.zeros((len(loader)))
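    # NOTE: these arrays are sized len(loader) and indexed by batch_idx, which assumes the loader
    # yields one whole-slide bag per batch (batch size of 1).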

    with torch.no_grad():
        for batch_idx, (data, features_flattened, label, event_time, censorship, stage, _) in enumerate(loader):
            data, label, censorship, stage = data.to(device), label.to(device), censorship.to(device), stage.to(device)
            _, _, Y_hat, _, _ = model_mol(features_flattened.to(device))

            hazards_prob, survival_prob, Y_hat, _, _ = model(data, stage, Y_hat.squeeze(1)) # Returns hazards, survival, Y_hat, A_raw, M.

            # The contribution of uncensored patient cases can be emphasized during training by minimizing a weighted sum of the two loss terms.
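            # A sketch of the usual discrete-time NLL survival loss (the exact form lives in
            # utils_loss.NLLSurvLoss): with per-bin hazards h, cumulative survival S, time-bin label Y
            # and censorship c,
            #   L = -(1 - c) * (log S(Y - 1) + log h(Y)) - c * log S(Y),
            # optionally re-weighted by alpha to emphasize the uncensored term; alpha=0 keeps both
            # terms at full weight.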
            loss = loss_fn(hazards=hazards_prob, S=survival_prob, Y=label, c=censorship, alpha=0)
            eval_loss += loss.item()
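
            # The risk score is a simple proxy: survival_prob holds the per-bin survival function S,
            # so -sum_k S(k) grows as predicted survival drops, which is all the concordance index
            # needs for ranking patients.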
            risk = -torch.sum(survival_prob, dim=1).cpu().numpy()
            all_risk_scores[batch_idx] = risk
            all_censorships[batch_idx] = censorship.cpu().numpy()
            all_event_times[batch_idx] = event_time
            all_survival_probs[batch_idx] = survival_prob.cpu().numpy()

    eval_loss /= len(loader)

    # Compute a few survival metrics.
    c_index = concordance_index_censored(
        event_indicator=(1-all_censorships).astype(bool),
        event_time=all_event_times,
        estimate=all_risk_scores, tied_tol=1e-08)[0]

    # Years of interest can be adapted in utils.py.
    (BS, years_of_interest), (IBS, yearI_of_interest, yearF_of_interest), (_, meanAUC), (c_index_ipcw) = compute_surv_metrics_eval(bins_values, all_survival_probs, all_risk_scores, train_BS, test_BS)

    print(f'Eval epoch: {epoch}, loss: {eval_loss}, c_index: {c_index}, BS at each {years_of_interest}Y: {BS}, IBS and mean cumAUC from {yearI_of_interest}Y to {yearF_of_interest}Y: {IBS} and {meanAUC}')

    writer.add_scalar("Loss/eval", eval_loss, epoch)
    writer.add_scalar("C_index/eval", c_index, epoch)
    for i in range(len(years_of_interest)):
        writer.add_scalar(f"eval_metrics/BS_{str(years_of_interest[i])}Y", BS[i], epoch)
    writer.add_scalar(f"eval_metrics/IBS_{str(yearI_of_interest)}Y-{str(yearF_of_interest)}Y", IBS, epoch)
    writer.add_scalar(f"eval_metrics/meanAUC_{str(yearI_of_interest)}Y-{str(yearF_of_interest)}Y", meanAUC, epoch)

    return eval_loss, c_index, (BS, IBS, meanAUC, c_index_ipcw)

def train_one_epoch(epoch, model, model_mol, device, train_loader, optimizer, n_bins, writer, loss_fn):

    model.train()
    epoch_start_time = time.time()
    train_loss = 0.

    all_risk_scores = np.zeros((len(train_loader))) # Computed risk score.
    all_censorships = np.zeros((len(train_loader))) # Binary censorship status: 1 censored; 0 uncensored.
    all_event_times = np.zeros((len(train_loader))) # True event time or last follow-up.

    batch_start_time = time.time()

    for batch_idx, (data, features_flattened, label, event_time, censorship, stage, _) in enumerate(train_loader):

        data_load_duration = time.time() - batch_start_time

        data, label, censorship, stage = data.to(device), label.to(device), censorship.to(device), stage.to(device)
        # To get the image-based molecular class, non-merged features are used because the molecular model was trained that way.
        # Merged features could be used alternatively.
        _, _, Y_hat, _, _ = model_mol(features_flattened.to(device))

        # Returns hazards, survival, Y_hat, A_raw, M.
        hazards_prob, survival_prob, Y_hat, _, _ = model(data, stage, Y_hat.squeeze(1))

        # Loss.
        loss = loss_fn(hazards=hazards_prob, S=survival_prob, Y=label, c=censorship)
        train_loss += loss.item()

        # Store outputs.
        risk = -torch.sum(survival_prob, dim=1).detach().cpu().numpy()
        all_risk_scores[batch_idx] = risk
        all_censorships[batch_idx] = censorship.item()
        all_event_times[batch_idx] = event_time

        # Backward pass.
        loss.backward()

        # Step.
        optimizer.step()
        optimizer.zero_grad()

        batch_duration = time.time() - batch_start_time
        batch_start_time = time.time()

        writer.add_scalar("duration/data_load", data_load_duration, epoch)
        writer.add_scalar("duration/batch", batch_duration, epoch)

    epoch_duration = time.time() - epoch_start_time
    print(f"Finished training on epoch {epoch} in {epoch_duration:.2f}s")

    train_loss /= len(train_loader)

    train_c_index = concordance_index_censored(
        event_indicator=(1-all_censorships).astype(bool),
        event_time=all_event_times,
        estimate=all_risk_scores, tied_tol=1e-08)[0]

    print(f'Epoch: {epoch}, epoch_duration: {epoch_duration}, train_loss: {train_loss}, train_c_index: {train_c_index}')

    filepath = os.path.join(writer.log_dir, f"{epoch}_checkpoint.pt")
    print(f"Saving model to {filepath}")
    torch.save(model.state_dict(), filepath)

    writer.add_scalar("duration/epoch", epoch_duration, epoch)
    writer.add_scalar("LR", get_lr(optimizer), epoch)
    writer.add_scalar("Loss/train", train_loss, epoch)
    writer.add_scalar("C_index/train", train_c_index, epoch)

def run_train_eval_loop(train_loader, val_loader, loss_fn, hparams, run_id, BS_data, checkpoint_model_molecular):
    writer = SummaryWriter(os.path.join("./runs", run_id))
    device = torch.device("cuda")
    n_bins = hparams["n_bins"]

    model = HECTOR(
        input_feature_size=hparams["input_feature_size"],
        precompression_layer=hparams["precompression_layer"],
        feature_size_comp=hparams["feature_size_comp"],
        feature_size_attn=hparams["feature_size_attn"],
        postcompression_layer=hparams["postcompression_layer"],
        feature_size_comp_post=hparams["feature_size_comp_post"],
        dropout=True,
        p_dropout_fc=hparams["p_dropout_fc"],
        p_dropout_atn=hparams["p_dropout_atn"],
        n_classes=n_bins,

        input_stage_size=hparams["input_stage_size"],
        embedding_dim_stage=hparams["embedding_dim_stage"],
        depth_dim_stage=hparams["depth_dim_stage"],
        act_fct_stage=hparams["act_fct_stage"],
        dropout_stage=hparams["dropout_stage"],
        p_dropout_stage=hparams["p_dropout_stage"],

        input_mol_size=4,
        embedding_dim_mol=hparams["embedding_dim_mol"],
        depth_dim_mol=hparams["depth_dim_mol"],
        act_fct_mol=hparams["act_fct_mol"],
        dropout_mol=hparams["dropout_mol"],
        p_dropout_mol=hparams["p_dropout_mol"],

        fusion_type=hparams["fusion_type"],
        use_bilinear=hparams["use_bilinear"],
        gate_hist=hparams["gate_hist"],
        gate_stage=hparams["gate_stage"],
        gate_mol=hparams["gate_mol"],
        scale=hparams["scale"],
    ).to(device)
    print('model')
    print_model(model)

    # This model is instantiated with the weights trained for molecular classification and is used in inference mode only.
    # NOTE: the molecular model (here im4MEC) must have been trained on the same training patients to avoid patient-level information leakage.
    model_mol = Im4MEC(
        input_feature_size=hparams["input_feature_size"],
        precompression_layer=True,
        feature_size_comp=hparams["feature_size_comp_molecular"],
        feature_size_attn=hparams["feature_size_attn_molecular"],
        n_classes=hparams["n_classes_molecular"],
        dropout=True, # Not used in inference.
        p_dropout_fc=0.25,
        p_dropout_atn=0.25,
    ).to(device)

    msg = model_mol.load_state_dict(torch.load(checkpoint_model_molecular, map_location=device), strict=True)
    print(msg)

    for p in model_mol.parameters():
        p.requires_grad = False
    print("HECTOR and plugged-in im4MEC are built and checkpoints loaded")
    model_mol.eval()

    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=hparams["initial_lr"],
        weight_decay=hparams["weight_decay"],
    )

    # Using a multi-step LR decay routine.
    milestones = [int(x) for x in hparams["milestones"].split(",")]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=hparams["gamma_lr"]
    )
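    # For example, with initial_lr=3e-5, gamma_lr=0.1 and milestones "2, 5, 15, 25" (the defaults set
    # in main below), the learning rate is multiplied by 0.1 at each of those epochs.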

    monitor_tracker = MonitorBestModelEarlyStopping(
        patience=hparams["earlystop_patience"],
        min_epochs=hparams["earlystop_min_epochs"],
        saving_checkpoint=True,
    )

    for epoch in range(hparams["max_epochs"]):

        train_one_epoch(epoch, model, model_mol, device, train_loader, optimizer, n_bins, writer, loss_fn)

        # Evaluation on the validation set.
        print("Evaluating model on validation set...")
        eval_loss, eval_cindex, eval_other_metrics = evaluate_model(epoch, model, model_mol, device, val_loader, n_bins, writer, loss_fn, hparams["bins_values"], *BS_data)
        monitor_tracker(epoch, eval_loss, eval_cindex, eval_other_metrics, model, writer.log_dir)

        # Update LR decay.
        scheduler.step()

        if monitor_tracker.early_stop:
            print(f"Early stop criterion reached. Broke off training loop after epoch {epoch}.")
            break

    # Log the hyperparameters of the experiment.
    runs_history = {
        "run_id": run_id,
        "best_epoch_CI": monitor_tracker.best_epoch_CI,
        "best_CI_score": monitor_tracker.best_CI_score,
        "best_epoch_loss": monitor_tracker.best_epoch_loss,
        "best_evalLoss": monitor_tracker.eval_loss_min,
        "BS": monitor_tracker.best_metrics_score[0],
        "IBS": monitor_tracker.best_metrics_score[1],
        "cumMeanAUC": monitor_tracker.best_metrics_score[2],
        "CI_ipwc": monitor_tracker.best_metrics_score[3],
        **hparams,
    }
    with open('runs_history.txt', 'a') as filehandle:
        for _, value in runs_history.items():
            filehandle.write('%s;' % value)
        filehandle.write('\n')
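
    # Each run appends one semicolon-separated line to runs_history.txt, with fields in the order of
    # the dict above (run_id, best epochs and scores, then every hparams entry).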

    writer.close()

def prepare_datasets(args):

    df = pd.read_csv(args.manifest)
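    # The manifest is expected to provide at least the 'disc_label', 'recurrence_years', 'stage' and
    # 'split' columns used below, plus whatever slide/feature identifiers FeatureBagsDataset needs.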

    n_bins = len(df['disc_label'].unique())
    assert n_bins == args.n_bins, 'mismatch between the number of bins passed in args and classes in dataset'
    bins_values = get_bins_time_value(df, n_bins, time_col_name='recurrence_years', label_time_col_name='disc_label')
    assert len(bins_values) == n_bins
    print(f'Read {args.manifest} dataset containing {len(df)} samples with {n_bins} bins with the following values {bins_values}')

    # NOTE: you may need the two lines below, depending on how the stage category is encoded in the csv file.
    #df.stage = df.stage.apply(lambda x: 'III' if 'III' in x else ('II' if 'II' in x else 'I')).astype("category")
    #df.stage = pd.Categorical(df['stage'], categories=['I', 'II', 'III'], ordered=True).codes
    print(f'stage taxonomy used: {df.stage.unique()}')

    try:
        training_set = df[df["split"] == "training"]
        validation_set = df[df["split"] == "validation"]
    except KeyError:
        raise Exception(
            f"Could not find training and validation splits in {args.manifest}"
        )

    train_split = FeatureBagsDataset(df=training_set,
                                     data_dir=args.data_dir,
                                     input_feature_size=args.input_feature_size,
                                     stage_class=len(training_set.stage.unique()))

    val_split = FeatureBagsDataset(df=validation_set,
                                   data_dir=args.data_dir,
                                   input_feature_size=args.input_feature_size,
                                   stage_class=len(validation_set.stage.unique()))

    # To compute the Brier score (BS), censorship status and times are needed in a specific format.
    _, train_BS = get_survival_data_for_BS(training_set, time_col_name='recurrence_years')
    _, test_BS = get_survival_data_for_BS(validation_set, time_col_name='recurrence_years')

    return train_split, val_split, train_BS, test_BS, bins_values, len(df.stage.unique())


def main(args):

    # Set random seed for some degree of reproducibility. See PyTorch docs on this topic for caveats.
    # https://pytorch.org/docs/stable/notes/randomness.html#reproducibility
    set_seed()

    if not torch.cuda.is_available():
        raise Exception(
            "No CUDA device available. Training without one is not feasible."
        )

    git_sha = subprocess.check_output(["git", "describe", "--always"]).strip().decode("utf-8")
    train_run_id = f"{git_sha}_hp{args.hp}_{time.strftime('%Y%m%d-%H%M')}"
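    # NOTE: the run id embeds the current commit hash via `git describe --always`, so the script is
    # expected to be launched from inside the repository checkout.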

    train_split, val_split, train_BS, test_BS, bins_values, stage_taxonomy = prepare_datasets(args)

    print(f"=> Run ID {train_run_id}")
    print(f"=> Training on {len(train_split)} samples")
    print(f"=> Validating on {len(val_split)} samples")

    base_hparams = dict(
        # Preprocessing settings. These should be adapted to the dataset being used.
        # Values are stored here for readability.
        n_bins=args.n_bins, # Partition of the continuous time scale.
        bins_values=bins_values,
        input_feature_size=args.input_feature_size,
        features_extraction=os.path.dirname(args.data_dir),

        # Settings that can be changed in the loop:
        # Training.
        sampling_method="random",
        max_epochs=100,
        earlystop_warmup=0,
        earlystop_patience=30,
        earlystop_min_epochs=30,

        # Loss.
        alpha_surv=0.0,

        # Optimizer.
        initial_lr=0.00003,
        milestones="2, 5, 15, 25",
        gamma_lr=0.1,
        weight_decay=0.00001,

        # Model architecture parameters. See the model class for details.
        precompression_layer=True,
        feature_size_comp=512,
        feature_size_attn=256,
        postcompression_layer=True,
        feature_size_comp_post=128,
        p_dropout_fc=0.25,
        p_dropout_atn=0.25,

        # Model of molecular classification. In our case only inference is used.
        n_classes_molecular=args.n_classes_molecular,
        feature_size_comp_molecular=args.feature_size_comp_molecular,
        feature_size_attn_molecular=args.feature_size_attn_molecular,

        # Fusion parameters.
        input_stage_size=stage_taxonomy,
        embedding_dim_stage=16,
        depth_dim_stage=1,
        act_fct_stage='elu',
        dropout_stage=True,
        p_dropout_stage=0.25,
        embedding_dim_mol=16,
        depth_dim_mol=1,
        act_fct_mol='elu',
        dropout_mol=True,
        p_dropout_mol=0.25,
        fusion_type='bilinear',
        use_bilinear=[True, True, True],
        gate_hist=True,
        gate_stage=True,
        gate_mol=True,
        scale=[2, 1, 1],
    )

    hparam_sets = [
        {
            **base_hparams,
        },
    ]

    hps = hparam_sets[args.hp]

    train_loader, val_loader = define_data_sampling(
        train_split,
        val_split,
        method=hps["sampling_method"],
        workers=args.workers,
    )

    run_train_eval_loop(
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=NLLSurvLoss(alpha=hps["alpha_surv"]), # Use the negative log-likelihood survival loss.
        hparams=hps,
        run_id=train_run_id,
        BS_data=(train_BS, test_BS),
        checkpoint_model_molecular=args.checkpoint_model_molecular,
    )
    print("Finished training.")

def get_args_parser():

    parser = argparse.ArgumentParser('Training script', add_help=False)

    parser.add_argument(
        "--manifest",
        type=str,
        help="CSV file listing all slides, their labels, and which split (training/validation) they belong to.",
    )
    parser.add_argument(
        "--n_bins",
        type=int,
        help="Number of time intervals used to create the time labels. It should match the manifest.",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        help="Directory where all *_features.h5 files are stored.",
    )
    parser.add_argument(
        "--input_feature_size",
        help="The size of the input features from the feature bags. Typical block output sizes are [96, 96, 192, 192, 384, 384, 384, 384, 768, 768].",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--checkpoint_model_molecular",
        type=str,
        default='',
        help="Path to a checkpoint of the trained im4MEC model.",
    )
    parser.add_argument(
        "--n_classes_molecular",
        type=int,
        required=True,
        help="Number of molecular classes predicted by the trained im4MEC model.",
    )
    parser.add_argument(
        "--feature_size_comp_molecular",
        type=int,
        required=True,
        help="Compression layer size of the trained im4MEC model. See im4MEC.py.",
    )
    parser.add_argument(
        "--feature_size_attn_molecular",
        type=int,
        required=True,
        help="Attention layer size of the trained im4MEC model. See im4MEC.py.",
    )
    parser.add_argument(
        "--workers",
        help="The number of workers to use for the data loaders.",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--hp",
        type=int,
        required=True,
        help="Index of the hyperparameter set to use from hparam_sets.",
    )

    return parser
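
# Example invocation (paths and sizes below are illustrative, not defaults):
#   python train.py --manifest manifest.csv --n_bins 4 --data_dir /path/to/features \
#       --input_feature_size 384 --checkpoint_model_molecular im4mec_checkpoint.pt \
#       --n_classes_molecular 4 --feature_size_comp_molecular 512 --feature_size_attn_molecular 256 \
#       --hp 0 --workers 4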

if __name__ == "__main__":

    parser = argparse.ArgumentParser('Training script', parents=[get_args_parser()])
    args = parser.parse_args()

    main(args)