[c128d9]: pipelines/base_pipeline.py

import json
import os
import os.path as osp
from datetime import datetime

import numpy as np
import plotly.graph_objects as go
import torch
import wfdb
from tqdm import tqdm

from utils.network_utils import load_checkpoint


class BasePipeline:
    """Runs a trained beat classifier over an ECG record and writes an
    annotated Plotly HTML plot of the signal to the results directory."""

    def __init__(self, config):
        self.config = config
        self.exp_name = self.config.get("exp_name", None)
        if self.exp_name is None:
            self.exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.res_dir = osp.join(self.config["exp_dir"], self.exp_name, "results")
        os.makedirs(self.res_dir, exist_ok=True)
        self.model = self._init_net()
        self.pipeline_loader = self._init_dataloader()
        # Invert the mapping so it goes from class index to annotation symbol.
        with open(config["mapping_json"]) as f:
            self.mapper = json.load(f)
        self.mapper = {j: i for i, j in self.mapper.items()}
        pretrained_path = self.config.get("model_path", False)
        if pretrained_path:
            load_checkpoint(pretrained_path, self.model)
        else:
            raise Exception(
                "model_path doesn't exist in config. Please specify a checkpoint path.",
            )

    def _init_net(self):
        raise NotImplementedError

    def _init_dataloader(self):
        raise NotImplementedError

    def run_pipeline(self):
        self.model.eval()
        pd_class = np.empty(0)
        pd_peaks = np.empty(0)
        # Inference loop: collect the predicted class and peak index per beat.
        with torch.no_grad():
            for i, batch in tqdm(enumerate(self.pipeline_loader)):
                inputs = batch["image"].to(self.config["device"])
                predictions = self.model(inputs)
                classes = predictions.topk(k=1)[1].view(-1).cpu().numpy()
                pd_class = np.concatenate((pd_class, classes))
                pd_peaks = np.concatenate((pd_peaks, batch["peak"]))
        pd_class = pd_class.astype(int)
        pd_peaks = pd_peaks.astype(int)

        # Annotate every predicted non-normal beat at its peak position.
        annotations = []
        for label, peak in zip(pd_class, pd_peaks):
            if (
                peak < len(self.pipeline_loader.dataset.signal)
                and self.mapper[label] != "N"
            ):
                annotations.append(
                    {
                        "x": peak,
                        "y": self.pipeline_loader.dataset.signal[peak],
                        "text": self.mapper[label],
                        "xref": "x",
                        "yref": "y",
                        "showarrow": True,
                        "arrowcolor": "black",
                        "arrowhead": 1,
                        "arrowsize": 2,
                    },
                )

        # If reference annotations (.atr) exist, overlay the ground-truth
        # non-normal labels slightly below the signal for comparison.
        if osp.exists(self.config["ecg_data"] + ".atr"):
            ann = wfdb.rdann(self.config["ecg_data"], extension="atr")
            for label, peak in zip(ann.symbol, ann.sample):
                if peak < len(self.pipeline_loader.dataset.signal) and label != "N":
                    annotations.append(
                        {
                            "x": peak,
                            "y": self.pipeline_loader.dataset.signal[peak] - 0.1,
                            "text": label,
                            "xref": "x",
                            "yref": "y",
                            "showarrow": False,
                            "bordercolor": "#c7c7c7",
                            "borderwidth": 1,
                            "borderpad": 4,
                            "bgcolor": "#ffffff",
                            "opacity": 1,
                        },
                    )

        # Plot the full signal with both predicted and reference annotations.
        fig = go.Figure(
            data=go.Scatter(
                x=list(range(len(self.pipeline_loader.dataset.signal))),
                y=self.pipeline_loader.dataset.signal,
            ),
        )
        fig.update_layout(
            title="ECG",
            xaxis_title="Time",
            yaxis_title="ECG Output Value",
            title_x=0.5,
            annotations=annotations,
            autosize=True,
        )
        fig.write_html(
            osp.join(self.res_dir, osp.basename(self.config["ecg_data"]) + ".html"),
        )