|
a |
|
b/pipelines/base_pipeline.py |
|
|
1 |
import json |
|
|
2 |
import os |
|
|
3 |
import os.path as osp |
|
|
4 |
from datetime import datetime |
|
|
5 |
|
|
|
6 |
import numpy as np |
|
|
7 |
import plotly.graph_objects as go |
|
|
8 |
import torch |
|
|
9 |
import wfdb |
|
|
10 |
from tqdm import tqdm |
|
|
11 |
|
|
|
12 |
from utils.network_utils import load_checkpoint |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
class BasePipeline: |
|
|
16 |
def __init__(self, config): |
|
|
17 |
self.config = config |
|
|
18 |
self.exp_name = self.config.get("exp_name", None) |
|
|
19 |
if self.exp_name is None: |
|
|
20 |
self.exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") |
|
|
21 |
|
|
|
22 |
self.res_dir = osp.join(self.config["exp_dir"], self.exp_name, "results") |
|
|
23 |
os.makedirs(self.res_dir, exist_ok=True) |
|
|
24 |
|
|
|
25 |
self.model = self._init_net() |
|
|
26 |
|
|
|
27 |
self.pipeline_loader = self._init_dataloader() |
|
|
28 |
|
|
|
29 |
self.mapper = json.load(open(config["mapping_json"])) |
|
|
30 |
self.mapper = {j: i for i, j in self.mapper.items()} |
|
|
31 |
|
|
|
32 |
pretrained_path = self.config.get("model_path", False) |
|
|
33 |
if pretrained_path: |
|
|
34 |
load_checkpoint(pretrained_path, self.model) |
|
|
35 |
else: |
|
|
36 |
raise Exception( |
|
|
37 |
"model_path doesnt't exist in config. Please specify checkpoint path", |
|
|
38 |
) |
|
|
39 |
|
|
|
40 |
def _init_net(self): |
|
|
41 |
raise NotImplemented |
|
|
42 |
|
|
|
43 |
def _init_dataloader(self): |
|
|
44 |
raise NotImplemented |
|
|
45 |
|
|
|
46 |
def run_pipeline(self): |
|
|
47 |
self.model.eval() |
|
|
48 |
pd_class = np.empty(0) |
|
|
49 |
pd_peaks = np.empty(0) |
|
|
50 |
|
|
|
51 |
with torch.no_grad(): |
|
|
52 |
for i, batch in tqdm(enumerate(self.pipeline_loader)): |
|
|
53 |
inputs = batch["image"].to(self.config["device"]) |
|
|
54 |
|
|
|
55 |
predictions = self.model(inputs) |
|
|
56 |
|
|
|
57 |
classes = predictions.topk(k=1)[1].view(-1).cpu().numpy() |
|
|
58 |
|
|
|
59 |
pd_class = np.concatenate((pd_class, classes)) |
|
|
60 |
pd_peaks = np.concatenate((pd_peaks, batch["peak"])) |
|
|
61 |
|
|
|
62 |
pd_class = pd_class.astype(int) |
|
|
63 |
pd_peaks = pd_peaks.astype(int) |
|
|
64 |
|
|
|
65 |
annotations = [] |
|
|
66 |
for label, peak in zip(pd_class, pd_peaks): |
|
|
67 |
if ( |
|
|
68 |
peak < len(self.pipeline_loader.dataset.signal) |
|
|
69 |
and self.mapper[label] != "N" |
|
|
70 |
): |
|
|
71 |
annotations.append( |
|
|
72 |
{ |
|
|
73 |
"x": peak, |
|
|
74 |
"y": self.pipeline_loader.dataset.signal[peak], |
|
|
75 |
"text": self.mapper[label], |
|
|
76 |
"xref": "x", |
|
|
77 |
"yref": "y", |
|
|
78 |
"showarrow": True, |
|
|
79 |
"arrowcolor": "black", |
|
|
80 |
"arrowhead": 1, |
|
|
81 |
"arrowsize": 2, |
|
|
82 |
}, |
|
|
83 |
) |
|
|
84 |
|
|
|
85 |
if osp.exists(self.config["ecg_data"] + ".atr"): |
|
|
86 |
ann = wfdb.rdann(self.config["ecg_data"], extension="atr") |
|
|
87 |
for label, peak in zip(ann.symbol, ann.sample): |
|
|
88 |
if peak < len(self.pipeline_loader.dataset.signal) and label != "N": |
|
|
89 |
annotations.append( |
|
|
90 |
{ |
|
|
91 |
"x": peak, |
|
|
92 |
"y": self.pipeline_loader.dataset.signal[peak] - 0.1, |
|
|
93 |
"text": label, |
|
|
94 |
"xref": "x", |
|
|
95 |
"yref": "y", |
|
|
96 |
"showarrow": False, |
|
|
97 |
"bordercolor": "#c7c7c7", |
|
|
98 |
"borderwidth": 1, |
|
|
99 |
"borderpad": 4, |
|
|
100 |
"bgcolor": "#ffffff", |
|
|
101 |
"opacity": 1, |
|
|
102 |
}, |
|
|
103 |
) |
|
|
104 |
|
|
|
105 |
fig = go.Figure( |
|
|
106 |
data=go.Scatter( |
|
|
107 |
x=list(range(len(self.pipeline_loader.dataset.signal))), |
|
|
108 |
y=self.pipeline_loader.dataset.signal, |
|
|
109 |
), |
|
|
110 |
) |
|
|
111 |
fig.update_layout( |
|
|
112 |
title="ECG", |
|
|
113 |
xaxis_title="Time", |
|
|
114 |
yaxis_title="ECG Output Value", |
|
|
115 |
title_x=0.5, |
|
|
116 |
annotations=annotations, |
|
|
117 |
autosize=True, |
|
|
118 |
) |
|
|
119 |
|
|
|
120 |
fig.write_html( |
|
|
121 |
osp.join(self.res_dir, osp.basename(self.config["ecg_data"] + ".html")), |
|
|
122 |
) |