Switch to unified view

a b/pipelines/base_pipeline.py
1
import json
2
import os
3
import os.path as osp
4
from datetime import datetime
5
6
import numpy as np
7
import plotly.graph_objects as go
8
import torch
9
import wfdb
10
from tqdm import tqdm
11
12
from utils.network_utils import load_checkpoint
13
14
15
class BasePipeline:
16
    def __init__(self, config):
17
        self.config = config
18
        self.exp_name = self.config.get("exp_name", None)
19
        if self.exp_name is None:
20
            self.exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
21
22
        self.res_dir = osp.join(self.config["exp_dir"], self.exp_name, "results")
23
        os.makedirs(self.res_dir, exist_ok=True)
24
25
        self.model = self._init_net()
26
27
        self.pipeline_loader = self._init_dataloader()
28
29
        self.mapper = json.load(open(config["mapping_json"]))
30
        self.mapper = {j: i for i, j in self.mapper.items()}
31
32
        pretrained_path = self.config.get("model_path", False)
33
        if pretrained_path:
34
            load_checkpoint(pretrained_path, self.model)
35
        else:
36
            raise Exception(
37
                "model_path doesnt't exist in config. Please specify checkpoint path",
38
            )
39
40
    def _init_net(self):
41
        raise NotImplemented
42
43
    def _init_dataloader(self):
44
        raise NotImplemented
45
46
    def run_pipeline(self):
47
        self.model.eval()
48
        pd_class = np.empty(0)
49
        pd_peaks = np.empty(0)
50
51
        with torch.no_grad():
52
            for i, batch in tqdm(enumerate(self.pipeline_loader)):
53
                inputs = batch["image"].to(self.config["device"])
54
55
                predictions = self.model(inputs)
56
57
                classes = predictions.topk(k=1)[1].view(-1).cpu().numpy()
58
59
                pd_class = np.concatenate((pd_class, classes))
60
                pd_peaks = np.concatenate((pd_peaks, batch["peak"]))
61
62
        pd_class = pd_class.astype(int)
63
        pd_peaks = pd_peaks.astype(int)
64
65
        annotations = []
66
        for label, peak in zip(pd_class, pd_peaks):
67
            if (
68
                peak < len(self.pipeline_loader.dataset.signal)
69
                and self.mapper[label] != "N"
70
            ):
71
                annotations.append(
72
                    {
73
                        "x": peak,
74
                        "y": self.pipeline_loader.dataset.signal[peak],
75
                        "text": self.mapper[label],
76
                        "xref": "x",
77
                        "yref": "y",
78
                        "showarrow": True,
79
                        "arrowcolor": "black",
80
                        "arrowhead": 1,
81
                        "arrowsize": 2,
82
                    },
83
                )
84
85
        if osp.exists(self.config["ecg_data"] + ".atr"):
86
            ann = wfdb.rdann(self.config["ecg_data"], extension="atr")
87
            for label, peak in zip(ann.symbol, ann.sample):
88
                if peak < len(self.pipeline_loader.dataset.signal) and label != "N":
89
                    annotations.append(
90
                        {
91
                            "x": peak,
92
                            "y": self.pipeline_loader.dataset.signal[peak] - 0.1,
93
                            "text": label,
94
                            "xref": "x",
95
                            "yref": "y",
96
                            "showarrow": False,
97
                            "bordercolor": "#c7c7c7",
98
                            "borderwidth": 1,
99
                            "borderpad": 4,
100
                            "bgcolor": "#ffffff",
101
                            "opacity": 1,
102
                        },
103
                    )
104
105
        fig = go.Figure(
106
            data=go.Scatter(
107
                x=list(range(len(self.pipeline_loader.dataset.signal))),
108
                y=self.pipeline_loader.dataset.signal,
109
            ),
110
        )
111
        fig.update_layout(
112
            title="ECG",
113
            xaxis_title="Time",
114
            yaxis_title="ECG Output Value",
115
            title_x=0.5,
116
            annotations=annotations,
117
            autosize=True,
118
        )
119
120
        fig.write_html(
121
            osp.join(self.res_dir, osp.basename(self.config["ecg_data"] + ".html")),
122
        )