"""VAE-style networks for ECG reconstruction and MI classification.

Reconstructed from a unified diff of ``model_ecg.py``; the patch markers
(``--- a``, ``+++ b``, leading ``+``) have been stripped.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import relu

# Project-local PointNet building blocks.  Unused in the visible code
# (presumably intended for the missing geometry encoder in ``InferenceNet``)
# but kept so the rest of the project keeps importing them.
from models.pointnet_utils import PointNetEncoder
from models.pointnet2_utils import PointNetSetAbstraction, PointNetFeaturePropagation


class ECGnet(nn.Module):
    """Encode an ECG signal into a latent code, then decode it back to a
    signal and classify myocardial infarction (MI) from the same code.

    forward(partial_input, signal_input) ->
        (y_MI, y_ECG, mu_signal, std_signal); ``partial_input`` is unused here.
    """

    def __init__(self, in_ch=3 + 4, out_ch=3, num_input=1024, z_dims=16):
        super(ECGnet, self).__init__()

        # ECG signal encoder producing (mean, std) of a z_dims-dim latent.
        self.encoder_signal = CRNN()

        # Signal decoder: two FC layers, reshape to a (1, 8, W) "image",
        # upsample to (8, 512), then refine with transposed convolutions.
        self.elu = nn.ELU(inplace=True)
        self.fc1 = nn.Linear(z_dims, 256 * 2)
        self.fc2 = nn.Linear(256 * 2, 512 * 2)
        self.up = nn.Upsample(size=(8, 512), mode='bilinear')
        self.deconv = DoubleDeConv(1, 1)

        # MI classification head operating on the latent code.
        self.decoder_MI = nn.Sequential(
            nn.Linear(z_dims, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, out_ch),
        )

    def reparameterize(self, mu, log_var):
        """Draw a latent sample via the reparameterization trick.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space

        NOTE(review): ``CRNN.forward`` returns (mean, std), not
        (mean, log-variance), yet this treats its second argument as a
        log-variance.  Kept as-is to preserve behavior — confirm which
        parameterization the trained weights expect.
        """
        std = torch.exp(0.5 * log_var)  # standard deviation
        eps = torch.randn_like(std)     # same shape/device/dtype as std
        return mu + (eps * std)         # sample as if drawn from q(z|x)

    def decode_signal(self, latent_z):  # P(x|z, c)
        """Decode a latent code into an ECG tensor of shape (bs, 1, 8, 512).

        :param latent_z: (bs, latent_size) latent sample
        """
        f = self.elu(self.fc1(latent_z))
        f = self.elu(self.fc2(f))
        u = self.up(f.reshape(f.shape[0], 1, 8, -1))
        return self.deconv(u)

    def forward(self, partial_input, signal_input):
        """Encode the signal, sample a latent, decode ECG and classify MI."""
        mu_signal, std_signal = self.encoder_signal(signal_input)
        latent_z_signal = self.reparameterize(mu_signal, std_signal)
        y_ECG = self.decode_signal(latent_z_signal)
        # F.softmax instead of constructing an nn.Softmax module per call.
        y_MI = F.softmax(self.decoder_MI(latent_z_signal), dim=1)
        return y_MI, y_ECG, mu_signal, std_signal


class InferenceNet(nn.Module):
    """Joint geometry+signal inference network.

    NOTE(review): ``forward`` references ``mu_geometry`` / ``std_geometry`` /
    ``y_seg`` / ``y_coarse`` / ``y_detail`` which are never defined — the
    geometry encoder/decoder appears to have been removed or not yet written
    (the PointNet imports at the top are likely meant for it).  As shipped,
    ``forward`` raises NameError; kept as-is rather than inventing the
    missing architecture.
    """

    def __init__(self, in_ch=3 + 4, out_ch=3, num_input=1024, z_dims=16):
        super(InferenceNet, self).__init__()

        self.z_dims = z_dims

        # Signal encoder.
        self.encoder_signal = CRNN()

        # Signal decoder; fc1 takes the *concatenated* (geometry + signal)
        # latent, hence z_dims * 2.
        self.elu = nn.ELU(inplace=True)
        self.fc1 = nn.Linear(z_dims * 2, 256 * 2)
        self.fc2 = nn.Linear(256 * 2, 512 * 2)
        self.up = nn.Upsample(size=(8, 512), mode='bilinear')
        self.deconv = DoubleDeConv(1, 1)

    def reparameterize(self, mu, log_var):
        """Draw a latent sample via the reparameterization trick.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        (Same std-vs-log-var caveat as ``ECGnet.reparameterize``.)
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + (eps * std)

    def decode_signal(self, latent_z):  # P(x|z, c)
        """Decode a fused latent code (bs, 2 * z_dims) into an ECG tensor."""
        f = self.elu(self.fc1(latent_z))
        f = self.elu(self.fc2(f))
        u = self.up(f.reshape(f.shape[0], 1, 8, -1))
        return self.deconv(u)

    def forward(self, partial_input, signal_input):
        num_points = partial_input.shape[-1]  # currently unused
        # Extract ECG features.
        mu_signal, std_signal = self.encoder_signal(signal_input)
        # latent_z_signal = self.reparameterize(mu_signal, std_signal)

        # Fuse geometry and signal features.
        # FIXME(review): mu_geometry / std_geometry are undefined here —
        # the geometry encoder is missing; this raises NameError.
        mu = torch.cat((mu_geometry, mu_signal), dim=1)
        log_var = torch.cat((std_geometry, std_signal), dim=1)
        latent_z = self.reparameterize(mu, log_var)

        y_ECG = self.decode_signal(latent_z)

        # FIXME(review): y_seg / y_coarse / y_detail are undefined here —
        # the geometry decoder is missing.
        return y_seg, y_coarse, y_detail, y_ECG, mu, log_var


class CRNN(nn.Module):
    """Convolutional + recurrent ECG encoder.

    Takes a (batch, n_lead, length) signal and returns the (mean, std) of a
    z_dims-dimensional latent distribution.

    :param n_lead: number of ECG leads (input channels)
    :param z_dims: latent dimensionality; the RNN emits 2 * z_dims features
        which are split into mean and std halves
    """

    def __init__(self, n_lead=8, z_dims=16):
        super(CRNN, self).__init__()

        n_out = 128
        self.z_dims = z_dims

        # Two strided 1-D conv stages: (b, n_lead, L) -> (b, 256, ~L/4).
        self.cnn = nn.Sequential(
            nn.Conv1d(n_lead, n_out, kernel_size=16, stride=2, padding=2),
            nn.BatchNorm1d(n_out),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(n_out, n_out * 2, kernel_size=16, stride=2, padding=2),
            nn.BatchNorm1d(n_out * 2),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Input size 256 matches the conv output channels (n_out * 2).
        self.rnn = BidirectionalLSTM(256, z_dims * 4, z_dims * 2)
        # self.rnn = nn.Sequential(
        #     BidirectionalLSTM(512, nh, nh),
        #     BidirectionalLSTM(nh, nh, 1))

    def forward(self, input):
        # Conv features: (b, c, w) -> sequence-first (w, b, c) for the LSTM.
        conv = self.cnn(input)
        b, c, w = conv.size()
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # RNN features, max-pooled over time: (b, 2 * z_dims).
        output = self.rnn(conv).permute(1, 0, 2)
        features = torch.max(output, 1)[0]
        mean = features[:, : self.z_dims]
        # Epsilon keeps the "std" half strictly away from zero.
        std = features[:, self.z_dims:] + 1e-6
        return mean, std

    def backward_hook(self, module, grad_input, grad_output):
        # NOTE(review): g != g only matches NaN, not inf, despite the
        # original comment claiming both; kept as-is.
        for g in grad_input:
            g[g != g] = 0  # replace all NaN in gradients with zero


class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear projection.

    Input/output are sequence-first: (T, b, nIn) -> (T, b, nOut).
    """

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        return output.view(T, b, -1)


class DoubleDeConv(nn.Module):
    """Two 3x3 transposed convolutions, each followed by BatchNorm + ELU.

    With padding=1 and stride=1 the spatial size is preserved.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleDeConv, self).__init__()
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ELU(inplace=True),
            nn.ConvTranspose2d(out_ch, out_ch, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ELU(inplace=True),
        )

    def forward(self, input):
        return self.conv(input)


def dtw_loss(ecg1, ecg2):  # TODO: plot the curve of x-y axis.
    """Dynamic Time Warping (DTW) loss between two batched ECG sequences.

    :param ecg1: first ECG sequence, shape (batch_size, seq_len1, num_features)
    :param ecg2: second ECG sequence, shape (batch_size, seq_len2, num_features)
    :return: scalar tensor — mean DTW distance normalized by path length

    The O(seq_len1 * seq_len2) dynamic program below uses .clone() on each
    read so in-place writes into ``dp`` stay autograd-safe.
    """
    batch_size, seq_len1, num_features = ecg1.size()
    _, seq_len2, _ = ecg2.size()

    # Pairwise distance matrix, shape (batch_size, seq_len1, seq_len2).
    distance_matrix = torch.cdist(ecg1, ecg2)

    # DP table of accumulated costs.
    dp = torch.zeros((batch_size, seq_len1, seq_len2)).to(ecg1.device)

    # Fill the DP table: first column, first row, then the interior.
    dp[:, 0, 0] = distance_matrix[:, 0, 0]
    for i in range(1, seq_len1):
        dp[:, i, 0] = distance_matrix[:, i, 0] + dp[:, i - 1, 0].clone()
    for j in range(1, seq_len2):
        dp[:, 0, j] = distance_matrix[:, 0, j] + dp[:, 0, j - 1].clone()
    for i in range(1, seq_len1):
        for j in range(1, seq_len2):
            dp[:, i, j] = distance_matrix[:, i, j] + torch.min(torch.stack([
                dp[:, i - 1, j].clone(),
                dp[:, i, j - 1].clone(),
                dp[:, i - 1, j - 1].clone()
            ], dim=1), dim=1).values

    # Normalize by path length and average over the batch.
    return torch.mean(dp[:, seq_len1 - 1, seq_len2 - 1] / (seq_len1 + seq_len2))


if __name__ == "__main__":
    # Smoke test: 8-lead ECG, 2048 samples per lead.  (The original
    # instantiated an undefined ``BetaVAE`` and always raised NameError.)
    signal = torch.rand(2, 8, 2048)
    network = ECGnet()
    y_MI, y_ECG, mu, std = network(None, signal)
    print(y_MI.size(), y_ECG.size())