"""
"""
import argparse
import logging
import os
import sys
import textwrap
from collections import OrderedDict, deque
from copy import deepcopy
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
import numpy as np
import torch
from tensorboardX import SummaryWriter
from torch import nn, optim
from torch.nn.parallel import DataParallel as DP
from torch.nn.parallel import DistributedDataParallel as DDP # noqa: F401
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
try:
import torch_ecg # noqa: F401
except ModuleNotFoundError:
sys.path.insert(0, str(Path(__file__).absolute().parents[2]))
from cfg import ModelCfg, TrainCfg
# from dataset import CPSC2020
from dataset_simplified import CPSC2020 as CPSC2020_SIMPLIFIED
from metrics import CPSC2020_loss, CPSC2020_score, eval_score # noqa: F401
from model import ECG_CRNN_CPSC2020, ECG_SEQ_LAB_NET_CPSC2020
from torch_ecg.cfg import CFG
from torch_ecg.components.outputs import BaseOutput # noqa: F401
from torch_ecg.components.trainer import BaseTrainer # noqa: F401
from torch_ecg.models.loss import BCEWithLogitsWithClassWeightLoss
from torch_ecg.utils.misc import dict_to_str, get_date_str, list_sum, str2bool
from torch_ecg.utils.utils_data import mask_to_intervals
from torch_ecg.utils.utils_nn import default_collate_fn as collate_fn
if ModelCfg.torch_dtype == torch.float64:
torch.set_default_tensor_type(torch.DoubleTensor)
_DTYPE = torch.float64
else:
_DTYPE = torch.float32
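# tensors created in this script are cast to `_DTYPE`, consistent with `ModelCfg.torch_dtype`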
__all__ = [
"train",
]
def train(
model: nn.Module,
model_config: dict,
device: torch.device,
config: dict,
logger: Optional[logging.Logger] = None,
debug: bool = False,
) -> OrderedDict:
"""
Parameters
----------
model: Module,
the model to train
model_config: dict,
config of the model, to store into the checkpoints
device: torch.device,
device on which the model trains
config: dict,
configurations of training, ref. `ModelCfg`, `TrainCfg`, etc.
logger: Logger, optional,
logger
    debug: bool, default False,
        if True, the training set itself will also be evaluated,
        to check whether the model is actually learning from it
Returns
-------
best_state_dict: OrderedDict,
state dict of the best model
"""
msg = f"training configurations are as follows:\n{dict_to_str(config)}"
config = CFG(config)
if logger:
logger.info(msg)
else:
print(msg)
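    # when the model is wrapped in `DataParallel`, checkpoints should store the state dict
    # of the underlying module, hence keep a reference to it as `_model`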
if type(model).__name__ in [
"DataParallel",
]: # TODO: further consider "DistributedDataParallel"
_model = model.module
else:
_model = model
config.log_dir = Path(config.log_dir)
config.log_dir.mkdir(parents=True, exist_ok=True)
config.checkpoints = Path(config.checkpoints)
config.checkpoints.mkdir(parents=True, exist_ok=True)
config.model_dir = Path(config.model_dir)
config.model_dir.mkdir(parents=True, exist_ok=True)
ds = CPSC2020_SIMPLIFIED
train_dataset = ds(config=config, training=True)
train_dataset.__DEBUG__ = False
if debug:
val_train_dataset = ds(config=config, training=True)
val_train_dataset.disable_data_augmentation()
val_train_dataset.__DEBUG__ = False
val_dataset = ds(config=config, training=False)
val_dataset.__DEBUG__ = False
n_train = len(train_dataset)
n_val = len(val_dataset)
n_epochs = config.n_epochs
batch_size = config.batch_size
lr = config.learning_rate
# https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813/4
num_workers = 4 * (torch.cuda.device_count() or 1)
train_loader = DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
drop_last=False,
collate_fn=collate_fn,
)
if debug:
val_train_loader = DataLoader(
dataset=val_train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
drop_last=False,
collate_fn=collate_fn,
)
val_loader = DataLoader(
dataset=val_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
drop_last=False,
collate_fn=collate_fn,
)
writer = SummaryWriter(
log_dir=str(config.log_dir),
filename_suffix=f"OPT_{config.model_name}_{config.cnn_name}_{config.train_optimizer}_LR_{lr}_BS_{batch_size}",
comment=f"OPT_{config.model_name}_{config.cnn_name}_{config.train_optimizer}_LR_{lr}_BS_{batch_size}",
)
# max_itr = n_epochs * n_train
msg = textwrap.dedent(
f"""
Starting training:
------------------
Epochs: {n_epochs}
Batch size: {batch_size}
Learning rate: {lr}
Training size: {n_train}
Validation size: {n_val}
Device: {device.type}
Optimizer: {config.train_optimizer}
-----------------------------------------
"""
)
# print(msg) # in case no logger
if logger:
logger.info(msg)
else:
print(msg)
if config.train_optimizer.lower() == "adam":
optimizer = optim.Adam(
params=model.parameters(),
lr=lr,
betas=config.betas,
eps=1e-08, # default
)
elif config.train_optimizer.lower() in ["adamw", "adamw_amsgrad"]:
optimizer = optim.AdamW(
params=model.parameters(),
lr=lr,
betas=config.betas,
weight_decay=config.decay,
eps=1e-08, # default
amsgrad=config.train_optimizer.lower().endswith("amsgrad"),
)
elif config.train_optimizer.lower() == "sgd":
optimizer = optim.SGD(
params=model.parameters(),
lr=lr,
momentum=config.momentum,
weight_decay=config.decay,
)
else:
raise NotImplementedError(f"optimizer `{config.train_optimizer}` not implemented!")
# scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)
if config.lr_scheduler is None:
scheduler = None
elif config.lr_scheduler.lower() == "plateau":
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
elif config.lr_scheduler.lower() == "step":
scheduler = optim.lr_scheduler.StepLR(optimizer, config.lr_step_size, config.lr_gamma)
elif config.lr_scheduler.lower() in [
"one_cycle",
"onecycle",
]:
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=config.max_lr,
epochs=n_epochs,
steps_per_epoch=len(train_loader),
)
else:
raise NotImplementedError(f"lr scheduler `{config.lr_scheduler.lower()}` not implemented for training")
if config.loss == "BCEWithLogitsLoss":
criterion = nn.BCEWithLogitsLoss()
elif config.loss == "BCEWithLogitsWithClassWeightLoss":
criterion = BCEWithLogitsWithClassWeightLoss(class_weight=train_dataset.class_weights.to(device=device, dtype=_DTYPE))
else:
raise NotImplementedError(f"loss `{config.loss}` not implemented!")
# scheduler = ReduceLROnPlateau(optimizer, mode="max", verbose=True, patience=6, min_lr=1e-7)
# scheduler = CosineAnnealingWarmRestarts(optimizer, 0.001, 1e-6, 20)
save_prefix = f"{_model.__name__}_{config.cnn_name}_{config.rnn_name}_epoch"
best_state_dict = OrderedDict()
best_challenge_metric = -np.inf
best_eval_res = tuple()
best_epoch = -1
pseudo_best_epoch = -1
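    # rotating buffer of recent checkpoint paths, so that at most
    # `config.keep_checkpoint_max` checkpoints are kept on disk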
saved_models = deque()
model.train()
global_step = 0
for epoch in range(n_epochs):
model.train()
epoch_loss = 0
with tqdm(
total=n_train,
desc=f"Epoch {epoch + 1}/{n_epochs}",
dynamic_ncols=True,
mininterval=1.0,
) as pbar:
for epoch_step, (signals, labels) in enumerate(train_loader):
global_step += 1
signals = signals.to(device=device, dtype=_DTYPE)
labels = labels.to(device=device, dtype=_DTYPE)
preds = model(signals)
loss = criterion(preds, labels)
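                # "flooding" regularization (Ishida et al., 2020): |loss - b| + b keeps the
                # training loss hovering around the flooding level b instead of driving it
                # to zero; gradients are taken with respect to the flooded loss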
if config.flooding_level > 0:
flood = (loss - config.flooding_level).abs() + config.flooding_level
epoch_loss += loss.item()
optimizer.zero_grad()
flood.backward()
else:
epoch_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
if global_step % config.log_step == 0:
writer.add_scalar("train/loss", loss.item(), global_step)
                    if scheduler:
                        # `scheduler.get_lr()` is deprecated and is not available for every
                        # scheduler (e.g. `ReduceLROnPlateau`), so read the lr from the optimizer
                        current_lr = optimizer.param_groups[0]["lr"]
                        writer.add_scalar("lr", current_lr, global_step)
                        pbar.set_postfix(
                            **{
                                "loss (batch)": loss.item(),
                                "lr": current_lr,
                            }
                        )
                        msg = f"Train step_{global_step}: loss : {loss.item()}, lr : {current_lr}"
else:
pbar.set_postfix(
**{
"loss (batch)": loss.item(),
}
)
msg = f"Train step_{global_step}: loss : {loss.item()}"
# print(msg) # in case no logger
if config.flooding_level > 0:
writer.add_scalar("train/flood", flood.item(), global_step)
msg = f"{msg}\nflood : {flood.item()}"
if logger:
logger.info(msg)
else:
print(msg)
pbar.update(signals.shape[0])
writer.add_scalar("train/epoch_loss", epoch_loss, global_step)
# eval for each epoch using corresponding `evaluate` function
if debug:
if config.model_name == "crnn":
eval_train_res = evaluate_crnn(model, val_train_loader, config, device, debug)
writer.add_scalar("train/auroc", eval_train_res[0], global_step)
writer.add_scalar("train/auprc", eval_train_res[1], global_step)
writer.add_scalar("train/accuracy", eval_train_res[2], global_step)
writer.add_scalar("train/f_measure", eval_train_res[3], global_step)
writer.add_scalar("train/f_beta_measure", eval_train_res[4], global_step)
writer.add_scalar("train/g_beta_measure", eval_train_res[5], global_step)
elif config.model_name == "seq_lab":
eval_train_res = evaluate_seq_lab(model, val_train_loader, config, device, debug)
writer.add_scalar("train/total_loss", eval_train_res.total_loss, global_step)
writer.add_scalar("train/spb_loss", eval_train_res.spb_loss, global_step)
writer.add_scalar("train/pvc_loss", eval_train_res.pvc_loss, global_step)
writer.add_scalar("train/spb_tp", eval_train_res.spb_tp, global_step)
writer.add_scalar("train/pvc_tp", eval_train_res.pvc_tp, global_step)
writer.add_scalar("train/spb_fp", eval_train_res.spb_fp, global_step)
writer.add_scalar("train/pvc_fp", eval_train_res.pvc_fp, global_step)
writer.add_scalar("train/spb_fn", eval_train_res.spb_fn, global_step)
writer.add_scalar("train/pvc_fn", eval_train_res.pvc_fn, global_step)
if config.model_name == "crnn":
eval_res = evaluate_crnn(model, val_loader, config, device, debug)
model.train()
writer.add_scalar("test/auroc", eval_res[0], global_step)
writer.add_scalar("test/auprc", eval_res[1], global_step)
writer.add_scalar("test/accuracy", eval_res[2], global_step)
writer.add_scalar("test/f_measure", eval_res[3], global_step)
writer.add_scalar("test/f_beta_measure", eval_res[4], global_step)
writer.add_scalar("test/g_beta_measure", eval_res[5], global_step)
if config.lr_scheduler is None:
pass
elif config.lr_scheduler.lower() == "plateau":
scheduler.step(metrics=eval_res[4])
elif config.lr_scheduler.lower() == "step":
scheduler.step()
elif config.lr_scheduler.lower() in [
"one_cycle",
"onecycle",
]:
scheduler.step()
if debug:
eval_train_msg = textwrap.dedent(
f"""
train/auroc: {eval_train_res[0]}
train/auprc: {eval_train_res[1]}
train/accuracy: {eval_train_res[2]}
train/f_measure: {eval_train_res[3]}
train/f_beta_measure: {eval_train_res[4]}
train/g_beta_measure: {eval_train_res[5]}
"""
)
else:
eval_train_msg = ""
msg = textwrap.dedent(
f"""
Train epoch_{epoch + 1}:
--------------------
train/epoch_loss: {epoch_loss}{eval_train_msg}
test/auroc: {eval_res[0]}
test/auprc: {eval_res[1]}
test/accuracy: {eval_res[2]}
test/f_measure: {eval_res[3]}
test/f_beta_measure: {eval_res[4]}
test/g_beta_measure: {eval_res[5]}
---------------------------------
"""
)
elif config.model_name == "seq_lab":
eval_res = evaluate_seq_lab(model, val_loader, config, device, debug)
model.train()
writer.add_scalar("test/total_loss", eval_res.total_loss, global_step)
writer.add_scalar("test/spb_loss", eval_res.spb_loss, global_step)
writer.add_scalar("test/pvc_loss", eval_res.pvc_loss, global_step)
writer.add_scalar("test/spb_tp", eval_res.spb_tp, global_step)
writer.add_scalar("test/pvc_tp", eval_res.pvc_tp, global_step)
writer.add_scalar("test/spb_fp", eval_res.spb_fp, global_step)
writer.add_scalar("test/pvc_fp", eval_res.pvc_fp, global_step)
writer.add_scalar("test/spb_fn", eval_res.spb_fn, global_step)
writer.add_scalar("test/pvc_fn", eval_res.pvc_fn, global_step)
if config.lr_scheduler is None:
pass
elif config.lr_scheduler.lower() == "plateau":
scheduler.step(metrics=-eval_res.total_loss)
elif config.lr_scheduler.lower() == "step":
scheduler.step()
elif config.lr_scheduler.lower() in [
"one_cycle",
"onecycle",
]:
scheduler.step()
if debug:
eval_train_msg = textwrap.dedent(
f"""
train/total_loss: {eval_train_res.total_loss}
train/spb_loss: {eval_train_res.spb_loss}
train/pvc_loss: {eval_train_res.pvc_loss}
train/spb_tp: {eval_train_res.spb_tp}
train/pvc_tp: {eval_train_res.pvc_tp}
train/spb_fp: {eval_train_res.spb_fp}
train/pvc_fp: {eval_train_res.pvc_fp}
train/spb_fn: {eval_train_res.spb_fn}
train/pvc_fn: {eval_train_res.pvc_fn}
"""
)
else:
eval_train_msg = ""
msg = textwrap.dedent(
f"""
Train epoch_{epoch + 1}:
--------------------
train/epoch_loss: {epoch_loss}{eval_train_msg}
test/total_loss: {eval_res.total_loss}
test/spb_loss: {eval_res.spb_loss}
test/pvc_loss: {eval_res.pvc_loss}
test/spb_tp: {eval_res.spb_tp}
test/pvc_tp: {eval_res.pvc_tp}
test/spb_fp: {eval_res.spb_fp}
test/pvc_fp: {eval_res.pvc_fp}
test/spb_fn: {eval_res.spb_fn}
test/pvc_fn: {eval_res.pvc_fn}
---------------------------------
"""
)
# print(msg) # in case no logger
if logger:
logger.info(msg)
else:
print(msg)
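        # the monitored quantity is "higher is better": the F_beta measure for the CRNN
        # classifier, and the negative challenge loss for the sequence labeling model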
monitor = eval_res[4] if config.model_name == "crnn" else -eval_res.total_loss
if monitor > best_challenge_metric:
best_challenge_metric = monitor
best_state_dict = _model.state_dict()
best_eval_res = deepcopy(eval_res)
best_epoch = epoch + 1
pseudo_best_epoch = epoch + 1
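        # early stopping: `pseudo_best_epoch` is also refreshed when the monitored metric
        # is within `min_delta` of the best value, so near-ties do not count towards patience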
elif config.early_stopping:
if monitor >= best_challenge_metric - config.early_stopping.min_delta:
pseudo_best_epoch = epoch + 1
elif epoch - pseudo_best_epoch > config.early_stopping.patience:
msg = f"early stopping is triggered at epoch {epoch + 1}"
if logger:
logger.info(msg)
else:
print(msg)
break
msg = textwrap.dedent(
f"""
best challenge metric = {best_challenge_metric},
obtained at epoch {best_epoch}
"""
)
if logger:
logger.info(msg)
else:
print(msg)
try:
config.checkpoints.mkdir(parents=True, exist_ok=True)
# if logger:
# logger.info("Created checkpoint directory")
except OSError:
pass
if config.model_name == "crnn":
save_suffix = f"epochloss_{epoch_loss:.5f}_fb_{eval_res[4]:.2f}_gb_{eval_res[5]:.2f}"
elif config.model_name == "seq_lab":
save_suffix = f"epochloss_{epoch_loss:.5f}_challenge_loss_{eval_res.total_loss}"
save_filename = f"{save_prefix}{epoch + 1}_{get_date_str()}_{save_suffix}.pth.tar"
save_path = config.checkpoints / save_filename
torch.save(
{
"model_state_dict": _model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"model_config": model_config,
"train_config": config,
"epoch": epoch + 1,
},
str(save_path),
)
if logger:
logger.info(f"Checkpoint {epoch + 1} saved!")
saved_models.append(save_path)
# remove outdated models
if len(saved_models) > config.keep_checkpoint_max > 0:
model_to_remove = saved_models.popleft()
            try:
                os.remove(model_to_remove)
            except Exception:
                err_msg = f"failed to remove {model_to_remove}"
                if logger:
                    logger.info(err_msg)
                else:
                    print(err_msg)
# save the best model
if best_challenge_metric > -np.inf:
if config.final_model_name:
save_filename = config.final_model_name
else:
if config.model_name == "crnn":
save_suffix = f"BestModel_fb_{best_eval_res[4]:.2f}_gb_{best_eval_res[5]:.2f}"
elif config.model_name == "seq_lab":
save_suffix = f"BestModel_challenge_loss_{best_eval_res.total_loss}"
save_filename = f"{save_prefix}_{get_date_str()}_{save_suffix}.pth.tar"
save_path = config.model_dir / save_filename
torch.save(
{
"model_state_dict": best_state_dict,
"model_config": model_config,
"train_config": config,
"epoch": best_epoch,
},
str(save_path),
)
if logger:
logger.info(f"Best model saved to {save_path}!")
writer.close()
if logger:
        # iterate over a copy, since `removeHandler` mutates `logger.handlers`
        for h in list(logger.handlers):
            h.close()
            logger.removeHandler(h)
del logger
logging.shutdown()
return best_state_dict
@torch.no_grad()
def evaluate_crnn(
model: nn.Module,
data_loader: DataLoader,
config: dict,
device: torch.device,
debug: bool = True,
logger: Optional[logging.Logger] = None,
) -> Tuple[float, ...]:
"""
Parameters
----------
model: Module,
the model to evaluate
data_loader: DataLoader,
the data loader for loading data for evaluation
config: dict,
evaluation configurations
device: torch.device,
device for evaluation
debug: bool, default True,
more detailed evaluation output
logger: Logger, optional,
logger to record detailed evaluation output,
if is None, detailed evaluation output will be printed
Returns
-------
eval_res: tuple of float,
evaluation results, including
auroc, auprc, accuracy, f_measure, f_beta_measure, g_beta_measure
"""
model.eval()
# data_loader.dataset.disable_data_augmentation()
if type(model).__name__ in [
"DataParallel",
]: # TODO: further consider "DistributedDataParallel"
_model = model.module
else:
_model = model
all_scalar_preds = []
all_bin_preds = []
all_labels = []
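    # `_model.inference` returns an output object (cf. `BaseOutput`) carrying
    # scalar probabilities (`.prob`) and binarized predictions (`.pred`)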
for signals, labels in data_loader:
signals = signals.to(device=device, dtype=_DTYPE)
labels = labels.numpy()
all_labels.append(labels)
if torch.cuda.is_available():
torch.cuda.synchronize()
model_output = _model.inference(signals)
all_scalar_preds.append(model_output.prob)
all_bin_preds.append(model_output.pred)
all_scalar_preds = np.concatenate(all_scalar_preds, axis=0)
all_bin_preds = np.concatenate(all_bin_preds, axis=0)
all_labels = np.concatenate(all_labels, axis=0)
classes = data_loader.dataset.all_classes
if debug:
msg = f"all_scalar_preds.shape = {all_scalar_preds.shape}, all_labels.shape = {all_labels.shape}"
if logger:
logger.info(msg)
else:
print(msg)
head_num = 5
head_scalar_preds = all_scalar_preds[:head_num, ...]
head_bin_preds = all_bin_preds[:head_num, ...]
head_preds_classes = [np.array(classes)[np.where(row)] for row in head_bin_preds]
head_labels = all_labels[:head_num, ...]
head_labels_classes = [np.array(classes)[np.where(row)] for row in head_labels]
for n in range(head_num):
msg = textwrap.dedent(
f"""
----------------------------------------------
scalar prediction: {[round(n, 3) for n in head_scalar_preds[n].tolist()]}
binary prediction: {head_bin_preds[n].tolist()}
labels: {head_labels[n].astype(int).tolist()}
predicted classes: {head_preds_classes[n].tolist()}
label classes: {head_labels_classes[n].tolist()}
----------------------------------------------
"""
)
if logger:
logger.info(msg)
else:
print(msg)
auroc, auprc, accuracy, f_measure, f_beta_measure, g_beta_measure = eval_score(
classes=classes,
truth=all_labels,
scalar_pred=all_scalar_preds,
binary_pred=all_bin_preds,
)
eval_res = auroc, auprc, accuracy, f_measure, f_beta_measure, g_beta_measure
model.train()
return eval_res
@torch.no_grad()
def evaluate_seq_lab(
model: nn.Module,
data_loader: DataLoader,
config: dict,
device: torch.device,
debug: bool = True,
logger: Optional[logging.Logger] = None,
) -> Dict[str, Union[int, float]]:
"""
Parameters
----------
model: Module,
the model to evaluate
data_loader: DataLoader,
the data loader for loading data for evaluation
config: dict,
evaluation configurations
device: torch.device,
device for evaluation
debug: bool, default True,
more detailed evaluation output
logger: Logger, optional,
logger to record detailed evaluation output,
if is None, detailed evaluation output will be printed
Returns
-------
    eval_res: dict,
        evaluation results, including the total challenge loss, the per-class (SPB, PVC) losses,
        and the true positive, false positive, and false negative counts of each class

    CAUTION
    -------
    without rpeaks detection, consecutive SPBs or consecutive PVCs might be falsely missed,
    hence resulting in higher than normal false negatives.
    for a more suitable eval pipeline, ref. CPSC2020_challenge.py
"""
model.eval()
# data_loader.dataset.disable_data_augmentation()
if type(model).__name__ in [
"DataParallel",
]: # TODO: further consider "DistributedDataParallel"
_model = model.module
else:
_model = model
all_scalar_preds = []
all_spb_preds = []
all_pvc_preds = []
all_spb_labels = []
all_pvc_labels = []
for signals, labels in data_loader:
signals = signals.to(device=device, dtype=_DTYPE)
labels = labels.numpy() # (batch_size, seq_len, 2 or 3)
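        # the labels are per-class masks at the model's reduced (downsampled) time resolution;
        # each mask is converted to intervals, and the interval midpoints are mapped back to
        # sample indices in the original signal via the model's `reduction` factor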
spb_intervals = [mask_to_intervals(seq, 1) for seq in labels[..., config.classes.index("S")]]
# print(spb_intervals)
spb_labels = [
[_model.reduction * (itv[0] + itv[1]) // 2 for itv in l_itv] if len(l_itv) > 0 else [] for l_itv in spb_intervals
]
# print(spb_labels)
all_spb_labels.append(spb_labels)
pvc_intervals = [mask_to_intervals(seq, 1) for seq in labels[..., config.classes.index("V")]]
pvc_labels = [
[_model.reduction * (itv[0] + itv[1]) // 2 for itv in l_itv] if len(l_itv) > 0 else [] for l_itv in pvc_intervals
]
all_pvc_labels.append(pvc_labels)
if torch.cuda.is_available():
torch.cuda.synchronize()
model_output = _model.inference(signals)
all_scalar_preds.append(model_output.prob)
all_spb_preds.append(model_output.SPB_indices)
all_pvc_preds.append(model_output.PVC_indices)
all_scalar_preds = np.concatenate(all_scalar_preds, axis=0)
# all_spb_preds = np.concatenate(all_spb_preds, axis=0)
# all_pvc_preds = np.concatenate(all_pvc_preds, axis=0)
# all_spb_labels = np.concatenate(all_spb_labels, axis=0)
# all_pvc_labels = np.concatenate(all_pvc_labels, axis=0)
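    # `list_sum` flattens the per-batch lists into one list of per-record index arrays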
all_spb_preds = [np.array(item) for item in list_sum(all_spb_preds)]
all_pvc_preds = [np.array(item) for item in list_sum(all_pvc_preds)]
all_spb_labels = [np.array(item) for item in list_sum(all_spb_labels)]
all_pvc_labels = [np.array(item) for item in list_sum(all_pvc_labels)]
eval_res_tmp = CFG(
CPSC2020_score(
spb_true=all_spb_labels,
pvc_true=all_pvc_labels,
spb_pred=all_spb_preds,
pvc_pred=all_pvc_preds,
verbose=1,
)
)
eval_res = CFG(
total_loss=eval_res_tmp.total_loss,
spb_loss=eval_res_tmp.class_loss.S,
pvc_loss=eval_res_tmp.class_loss.V,
spb_tp=eval_res_tmp.true_positive.S,
pvc_tp=eval_res_tmp.true_positive.V,
spb_fp=eval_res_tmp.false_positive.S,
pvc_fp=eval_res_tmp.false_positive.V,
spb_fn=eval_res_tmp.false_negative.S,
pvc_fn=eval_res_tmp.false_negative.V,
)
model.train()
return eval_res
def get_args(**kwargs):
""" """
cfg = deepcopy(kwargs)
parser = argparse.ArgumentParser(
description="Train the Model on CPSC2020",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# parser.add_argument(
# "-l", "--learning-rate",
# metavar="LR", type=float, nargs="?", default=0.001,
# help="Learning rate",
# dest="learning_rate")
parser.add_argument(
"-b",
"--batch-size",
type=int,
default=128,
help="the batch size for training",
dest="batch_size",
)
parser.add_argument(
"-m",
"--model-name",
type=str,
default="crnn",
help="name of the model to train",
dest="model_name",
)
parser.add_argument(
"-c",
"--cnn-name",
type=str,
default="multi_scopic",
help="choice of cnn feature extractor",
dest="cnn_name",
)
parser.add_argument(
"-r",
"--rnn-name",
type=str,
default="linear",
help="choice of rnn structures",
dest="rnn_name",
)
parser.add_argument(
"--keep-checkpoint-max",
type=int,
default=50,
help="maximum number of checkpoints to keep. If set 0, all checkpoints will be kept",
dest="keep_checkpoint_max",
)
parser.add_argument(
"--optimizer",
type=str,
default="adam",
help="training optimizer",
dest="train_optimizer",
)
parser.add_argument(
"--debug",
type=str2bool,
default=False,
help="train with more debugging information",
dest="debug",
)
args = vars(parser.parse_args())
cfg.update(args)
return CFG(cfg)
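# example invocation (the script filename is assumed; flags as defined in `get_args`):
#   python train.py -m crnn -c multi_scopic -b 128 --optimizer adamw_amsgrad --debug true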
if __name__ == "__main__":
from utils import init_logger
train_config = get_args(**TrainCfg)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# classes = train_config.classes
model_name = train_config.model_name.lower()
classes = deepcopy(ModelCfg[model_name].classes)
class_map = deepcopy(ModelCfg[model_name].class_map)
if model_name == "crnn":
model_config = deepcopy(ModelCfg.crnn)
elif model_name == "seq_lab":
model_config = deepcopy(ModelCfg.seq_lab)
train_config.classes = deepcopy(model_config.classes)
train_config.class_map = deepcopy(model_config.class_map)
model_config.model_name = model_name
model_config.cnn.name = train_config.cnn_name
model_config.rnn.name = train_config.rnn_name
if model_name == "crnn":
# model = ECG_CRNN(
model = ECG_CRNN_CPSC2020(
classes=classes,
n_leads=train_config.n_leads,
input_len=train_config.input_len,
config=model_config,
)
elif model_name == "seq_lab":
model = ECG_SEQ_LAB_NET_CPSC2020(
classes=classes,
n_leads=train_config.n_leads,
input_len=train_config.input_len,
config=model_config,
)
else:
raise NotImplementedError(f"Model {model_name} not supported yet!")
if torch.cuda.device_count() > 1:
model = DP(model)
# model = DDP(model)
model.to(device=device)
model.__DEBUG__ = False
logger = init_logger(log_dir=str(train_config.log_dir), verbose=2)
logger.info(f"\n{'*'*20} Start Training {'*'*20}\n")
logger.info(f"Model name = {train_config.model_name}")
logger.info(f"Using device {device}")
logger.info(f"Using torch of version {torch.__version__}")
logger.info(f"with configuration\n{dict_to_str(train_config)}")
# print(f"\n{'*'*20} Start Training {'*'*20}\n")
# print(f"Using device {device}")
# print(f"Using torch of version {torch.__version__}")
# print(f"with configuration\n{dict_to_str(train_config)}")
try:
train(
model=model,
model_config=model_config,
config=train_config,
device=device,
logger=logger,
debug=train_config.debug,
)
except KeyboardInterrupt:
torch.save(
{
"model_state_dict": model.state_dict(),
"model_config": model_config,
"train_config": train_config,
},
str(train_config.checkpoints / "INTERRUPTED.pth.tar"),
)
logger.info("Saved interrupt")
try:
sys.exit(0)
except SystemExit:
os._exit(0)