[6bf179]: / ecg_gan / gan.py

Download this file

54 lines (46 with data), 1.5 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class Generator(nn.Module):
    """Bidirectional-LSTM generator producing one output vector per input time step.

    The LSTM runs batch-first over the input. Each time step's concatenated
    forward/backward hidden state is passed through three linear layers
    (leaky-ReLU, leaky-ReLU, dropout, linear). The result gets a singleton
    axis inserted at dim 1.

    Args:
        signal_length: Feature size of each input step and of each generated
            output vector (187 by default, as in the original hard-coded model).
        hidden_size: Hidden units per LSTM direction; the fully connected
            layers operate on ``2 * hidden_size`` features.
        dropout: Dropout probability applied before the output layer. It is
            active only while the module is in training mode.
    """

    def __init__(self, signal_length=187, hidden_size=128, dropout=0.2):
        super().__init__()
        # Forward and backward LSTM outputs are concatenated on the last axis.
        features = 2 * hidden_size
        self.dropout_p = dropout
        # Creation order (fc1, fc2, fc3, rnn_layer) matches the original so that
        # seeded parameter initialization and state_dict layout are unchanged.
        self.fc1 = nn.Linear(features, features)
        self.fc2 = nn.Linear(features, features)
        self.fc3 = nn.Linear(features, signal_length)
        self.rnn_layer = nn.LSTM(
            input_size=signal_length,
            hidden_size=hidden_size,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )

    def forward(self, x):
        """Generate output vectors from ``x``.

        Args:
            x: Batch-first tensor, expected shape ``(batch, seq, signal_length)``.

        Returns:
            Tensor of shape ``(batch * seq, 1, signal_length)``. This equals the
            input shape when ``seq == 1``.
        """
        x, _ = self.rnn_layer(x)
        # Fold time steps into the batch axis so each step is projected independently.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        # F.dropout defaults to training=True; tie it to the module mode so that
        # .eval() disables dropout and inference becomes deterministic.
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.fc3(x)
        return x.unsqueeze(1)
class Discriminator(nn.Module):
    """Bidirectional-LSTM discriminator producing one sigmoid score per time step.

    The LSTM runs batch-first over the input. Each time step's concatenated
    forward/backward hidden state is passed through three linear layers
    (leaky-ReLU, leaky-ReLU, dropout, linear), followed by a sigmoid.

    Args:
        signal_length: Feature size of each input step (187 by default, as in
            the original hard-coded model).
        hidden_size: Hidden units per LSTM direction. ``fc1`` maps
            ``2 * hidden_size -> 2 * hidden_size``, ``fc2`` maps
            ``2 * hidden_size -> hidden_size`` and ``fc3`` maps
            ``hidden_size -> 1``. The default of 256 reproduces the original
            512 -> 512 -> 256 -> 1 stack.
        dropout: Dropout probability applied before the output layer. It is
            active only while the module is in training mode.
    """

    def __init__(self, signal_length=187, hidden_size=256, dropout=0.2):
        super().__init__()
        self.dropout_p = dropout
        # Creation order (rnn_layer, fc1, fc2, fc3) matches the original so that
        # seeded parameter initialization and state_dict layout are unchanged.
        self.rnn_layer = nn.LSTM(
            input_size=signal_length,
            hidden_size=hidden_size,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        # Forward and backward LSTM outputs are concatenated on the last axis.
        features = 2 * hidden_size
        self.fc1 = nn.Linear(features, features)
        self.fc2 = nn.Linear(features, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Score ``x``.

        Args:
            x: Batch-first tensor, expected shape ``(batch, seq, signal_length)``.

        Returns:
            Tensor of shape ``(batch * seq, 1)`` with values in ``[0, 1]``
            (sigmoid output).
        """
        x, _ = self.rnn_layer(x)
        # Fold time steps into the batch axis so each step is scored independently.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        # F.dropout defaults to training=True; tie it to the module mode so that
        # .eval() disables dropout and scoring becomes deterministic.
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        return torch.sigmoid(self.fc3(x))