|
# train.py
|
|
import argparse
import torch
torch.cuda.empty_cache()  # release any cached, unoccupied CUDA memory
from torch.backends import cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
import os
import numpy as np

# cap the allocator's split size to reduce CUDA memory fragmentation;
# this variable is read when the CUDA context is first initialised
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"

from dataset import LoadDataset
from model import InferenceNet, ECGnet
from loss import calculate_inference_loss, calculate_reconstruction_loss, calculate_ECG_reconstruction_loss, calculate_classify_loss
from utils import lossplot, lossplot_detailed, visualize_PC_with_label, ECG_visual_two, lossplot_classify, visualize_PC_with_twolabel
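
# Two trainers are defined below: train() optimises the full inference model
# (MI segmentation plus compactness, RV-penalty, MI-size, KL and point-cloud
# reconstruction terms), while train_ecg() optimises a classification-only
# variant (segmentation + KL terms). Only train() is invoked from __main__.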
|
|
def train_ecg(args):
    DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    # DEVICE = torch.device('cpu')
    train_dataset = LoadDataset(path=args.partial_root, num_input=args.num_input, split='train')
    val_dataset = LoadDataset(path=args.partial_root, num_input=args.num_input, split='val')
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
    cudnn.benchmark = True
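    # cudnn.benchmark auto-tunes convolution algorithms; a good fit here since
    # the input size (num_input points) is fixed, but it can hurt if shapes vary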
|
|

    network = InferenceNet(in_ch=args.in_ch, out_ch=args.out_ch, num_input=args.num_input, z_dims=args.z_dims)

    if args.model is not None:
        print('Loaded trained model from {}.'.format(args.model))
        # map_location keeps the checkpoint loadable on CPU-only machines
        network.load_state_dict(torch.load(args.model, map_location=DEVICE))
    else:
        print('Begin training a new model.')

    network.to(DEVICE)
    optimizer = optim.Adam(network.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_steps, gamma=args.lr_decay_rate)

    max_iter = int(len(train_dataset) / args.batch_size + 0.5)  # approximate iterations per epoch
    minimum_loss = 1e4
    best_epoch = 0
|
|

    lossfile_train = args.log_dir + "/training_loss.txt"
    lossfile_val = args.log_dir + "/val_loss.txt"
    lossfile_geometry_train = args.log_dir + "/training_calculate_inference_loss.txt"
    lossfile_geometry_val = args.log_dir + "/val_calculate_inference_loss.txt"
    lossfile_KL_train = args.log_dir + "/training_KL_loss.txt"
    lossfile_KL_val = args.log_dir + "/val_KL_loss.txt"
    lossfile_ecg_train = args.log_dir + "/training_ecg_loss.txt"
    lossfile_ecg_val = args.log_dir + "/val_ecg_loss.txt"

    for epoch in range(1, args.epochs + 1):
        # plot the loss curves every 25 epochs (epoch starts at 1, so the
        # original `epoch != 0` guard was redundant)
        if (epoch % 25) == 0:
            lossplot_classify(lossfile_train, lossfile_val, lossfile_geometry_train, lossfile_geometry_val, lossfile_KL_train, lossfile_KL_val, lossfile_ecg_train, lossfile_ecg_val)

        # 'a' appends to the existing log file; 'w' would overwrite it
        f_train = open(lossfile_train, 'a')
        f_val = open(lossfile_val, 'a')
        f_MI_train = open(lossfile_geometry_train, 'a')
        f_MI_val = open(lossfile_geometry_val, 'a')
        f_KL_train = open(lossfile_KL_train, 'a')
        f_KL_val = open(lossfile_KL_val, 'a')
        f_ecg_train = open(lossfile_ecg_train, 'a')
        f_ecg_val = open(lossfile_ecg_val, 'a')
|
|

        # if ((epoch % 25) == 0) and (epoch != 0):
        #     if lamda_KL < 1:
        #         lamda_KL = 0.1 * epoch * lamda_KL  # 0.25
        #     else:
        #         lamda_KL = 0.1
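        # A minimal sketch of the KL warm-up idea disabled above, assuming a
        # linear ramp (hypothetical helper, not part of the original code):
        #     def kl_weight(epoch, target=1e-2, warmup_epochs=100):
        #         return target * min(1.0, epoch / warmup_epochs)
        # which would replace args.lamda_KL in the loss below with kl_weight(epoch).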
|
|

        # training
        network.train()
        total_loss, iter_count = 0, 0
        for i, data in enumerate(train_dataloader, 1):
            partial_input, ECG_input, gt_MI, partial_input_coarse = data
            partial_input, ECG_input, gt_MI = partial_input.to(DEVICE), ECG_input.to(DEVICE), gt_MI.to(DEVICE)
            partial_input_coarse = partial_input_coarse.to(DEVICE)
            partial_input = partial_input.permute(0, 2, 1)  # (B, N, C) -> (B, C, N)

            optimizer.zero_grad()

            y_MI, y_ECG, mu, log_var = network(partial_input, ECG_input)

            loss_seg, KL_loss = calculate_classify_loss(y_MI, gt_MI, mu, log_var)
            loss_signal = calculate_ECG_reconstruction_loss(y_ECG, ECG_input)
            loss = loss_seg + args.lamda_KL * KL_loss  # loss_signal is logged but not optimised

            check_grad = False
            if check_grad:
                print(loss_seg)
                print(loss_signal)
                print(KL_loss)

                print(loss.requires_grad)
                print(loss_seg.requires_grad)
                print(KL_loss.requires_grad)
                print(loss_signal.requires_grad)

            visual_check = False
            if visual_check:
                gd_ECG = ECG_input[0].cpu().detach().numpy()
                # use a new name so the y_ECG tensor is not shadowed by a numpy array
                pred_ECG = y_ECG[0].cpu().detach().numpy()
                ECG_visual_two(pred_ECG, gd_ECG)

            loss.backward()
            optimizer.step()

            f_train.write(str(loss.item()) + '\n')
            f_MI_train.write(str(loss_seg.item()) + '\n')
            f_KL_train.write(str(KL_loss.item()) + '\n')
            f_ecg_train.write(str(loss_signal.item()) + '\n')

            iter_count += 1
            total_loss += loss.item()

            if i % 50 == 0:
                print("Training epoch {}/{}, iteration {}/{}: loss is {}".format(epoch, args.epochs, i, max_iter, loss.item()))
        scheduler.step()  # decay the learning rate once per epoch

        print("\033[96mTraining epoch {}/{}: avg loss = {}\033[0m".format(epoch, args.epochs, total_loss / iter_count))
|
|

        # evaluation
        network.eval()
        with torch.no_grad():
            total_loss, iter_count = 0, 0
            for i, data in enumerate(val_dataloader, 1):
                partial_input, ECG_input, gt_MI, partial_input_coarse = data
                partial_input, ECG_input, gt_MI = partial_input.to(DEVICE), ECG_input.to(DEVICE), gt_MI.to(DEVICE)
                partial_input_coarse = partial_input_coarse.to(DEVICE)
                partial_input = partial_input.permute(0, 2, 1)

                y_MI, y_ECG, mu, log_var = network(partial_input, ECG_input)

                loss_seg, KL_loss = calculate_classify_loss(y_MI, gt_MI, mu, log_var)
                loss_signal = calculate_ECG_reconstruction_loss(y_ECG, ECG_input)
                loss = loss_seg + args.lamda_KL * KL_loss

                total_loss += loss.item()
                iter_count += 1

                visual_check = False
                if visual_check:
                    gd_ECG = ECG_input[0].cpu().detach().numpy()
                    pred_ECG = y_ECG[0].cpu().detach().numpy()
                    ECG_visual_two(pred_ECG, gd_ECG)

                f_val.write(str(loss.item()) + '\n')
                f_MI_val.write(str(loss_seg.item()) + '\n')
                f_KL_val.write(str(KL_loss.item()) + '\n')
                f_ecg_val.write(str(loss_signal.item()) + '\n')

        mean_loss = total_loss / iter_count
        print("\033[35mValidation epoch {}/{}, loss is {}\033[0m".format(epoch, args.epochs, mean_loss))

        # record the best model and epoch
        if mean_loss < minimum_loss:
            best_epoch = epoch
            minimum_loss = mean_loss
            strNetSaveName = 'net_model_classify.pkl'
            # strNetSaveName = 'net_with_%d.pkl' % epoch
            torch.save(network.state_dict(), args.log_dir + '/' + strNetSaveName)

        print("\033[4;37mBest model (lowest loss) in epoch {}\033[0m".format(best_epoch))

        # close the per-epoch log handles so file descriptors do not leak
        for f in (f_train, f_val, f_MI_train, f_MI_val, f_KL_train, f_KL_val, f_ecg_train, f_ecg_val):
            f.close()

    lossplot(lossfile_train, lossfile_val)
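
# train() mirrors train_ecg() but optimises the full model: the network also
# returns coarse/detail point-cloud reconstructions, each data tuple carries an
# extra MI_type field, and the objective adds compactness, RV-penalty, MI-size
# and geometry-reconstruction terms.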
|
|

def train(args):
    DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    # DEVICE = torch.device('cpu')
    train_dataset = LoadDataset(path=args.partial_root, num_input=args.num_input, split='train')
    val_dataset = LoadDataset(path=args.partial_root, num_input=args.num_input, split='val')
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
    cudnn.benchmark = True
|
|

    network = InferenceNet(in_ch=args.in_ch, out_ch=args.out_ch, num_input=args.num_input, z_dims=args.z_dims)

    if args.model is not None:
        print('Loaded trained model from {}.'.format(args.model))
        # map_location keeps the checkpoint loadable on CPU-only machines
        network.load_state_dict(torch.load(args.model, map_location=DEVICE))
    else:
        print('Begin training a new model.')

    network.to(DEVICE)
    optimizer = optim.Adam(network.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_steps, gamma=args.lr_decay_rate)

    max_iter = int(len(train_dataset) / args.batch_size + 0.5)  # approximate iterations per epoch
    minimum_loss = 1e4
    best_epoch = 0
|
|

    lossfile_train = args.log_dir + "/training_loss.txt"
    lossfile_val = args.log_dir + "/val_loss.txt"
    lossfile_geometry_train = args.log_dir + "/training_calculate_inference_loss.txt"
    lossfile_geometry_val = args.log_dir + "/val_calculate_inference_loss.txt"
    lossfile_compactness_train = args.log_dir + "/training_compactness_loss.txt"
    lossfile_compactness_val = args.log_dir + "/val_compactness_loss.txt"
    lossfile_KL_train = args.log_dir + "/training_KL_loss.txt"
    lossfile_KL_val = args.log_dir + "/val_KL_loss.txt"
    lossfile_PC_train = args.log_dir + "/training_PC_loss.txt"
    lossfile_PC_val = args.log_dir + "/val_PC_loss.txt"
    lossfile_ecg_train = args.log_dir + "/training_ecg_loss.txt"
    lossfile_ecg_val = args.log_dir + "/val_ecg_loss.txt"
    lossfile_RVp_train = args.log_dir + "/training_RVp_loss.txt"
    lossfile_RVp_val = args.log_dir + "/val_RVp_loss.txt"
    lossfile_size_train = args.log_dir + "/training_MIsize_loss.txt"
    lossfile_size_val = args.log_dir + "/val_MIsize_loss.txt"

    lamda_KL = args.lamda_KL
    for epoch in range(1, args.epochs + 1):
        # plot the loss curves every 25 epochs (epoch starts at 1, so the
        # original `epoch != 0` guard was redundant)
        if (epoch % 25) == 0:
            lossplot_detailed(lossfile_train, lossfile_val, lossfile_geometry_train, lossfile_geometry_val, lossfile_KL_train, lossfile_KL_val, lossfile_compactness_train, lossfile_compactness_val, lossfile_PC_train, lossfile_PC_val, lossfile_ecg_train, lossfile_ecg_val, lossfile_RVp_train, lossfile_RVp_val, lossfile_size_train, lossfile_size_val)

        # 'a' appends to the existing log file; 'w' would overwrite it
        f_train = open(lossfile_train, 'a')
        f_val = open(lossfile_val, 'a')
        f_MI_train = open(lossfile_geometry_train, 'a')
        f_MI_val = open(lossfile_geometry_val, 'a')
        f_compactness_train = open(lossfile_compactness_train, 'a')
        f_compactness_val = open(lossfile_compactness_val, 'a')
        f_KL_train = open(lossfile_KL_train, 'a')
        f_KL_val = open(lossfile_KL_val, 'a')
        f_PC_train = open(lossfile_PC_train, 'a')
        f_PC_val = open(lossfile_PC_val, 'a')
        f_ecg_train = open(lossfile_ecg_train, 'a')
        f_ecg_val = open(lossfile_ecg_val, 'a')
        f_size_train = open(lossfile_size_train, 'a')
        f_size_val = open(lossfile_size_val, 'a')
        f_RVp_train = open(lossfile_RVp_train, 'a')
        f_RVp_val = open(lossfile_RVp_val, 'a')
|
|

        # if epoch != 0:
        #     if lamda_KL < 1:
        #         lamda_KL = 0.1 * epoch * args.lamda_KL
        #     else:
        #         lamda_KL = 0.1

        # training
        network.train()
        total_loss, iter_count = 0, 0
        for i, data in enumerate(train_dataloader, 1):
            partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type = data
            partial_input, ECG_input, gt_MI = partial_input.to(DEVICE), ECG_input.to(DEVICE), gt_MI.to(DEVICE)
            partial_input_coarse = partial_input_coarse.to(DEVICE)
            partial_input = partial_input.permute(0, 2, 1)  # (B, N, C) -> (B, C, N)

            optimizer.zero_grad()

            # feed only the first 7 feature channels (presumably 3 coordinates
            # + 4 label channels, matching the --in_ch default of 3+4)
            y_MI, y_coarse, y_detail, y_ECG, mu, log_var = network(partial_input[:, 0:7, :], ECG_input)

            loss_seg, loss_compactness, loss_MI_RVpenalty, loss_MI_size, KL_loss = calculate_inference_loss(y_MI, gt_MI, mu, log_var, partial_input)
            loss_geo, loss_signal = calculate_reconstruction_loss(y_coarse, y_detail, partial_input_coarse, partial_input, y_ECG, ECG_input)
            loss = loss_seg + args.lamda_compact*loss_compactness + args.lamda_RVp*loss_MI_RVpenalty + args.lamda_MIsize*loss_MI_size + args.lamda_KL*KL_loss + args.lamda_recon*loss_geo  # + args.lamda_recon*loss_signal
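            # Weighted multi-task objective:
            #   L = L_seg + λ_compact·L_compact + λ_RVp·L_RVp + λ_MIsize·L_size
            #       + λ_KL·KL + λ_recon·L_geo
            # Note that the ECG term (loss_signal) is computed and logged but
            # left out of the objective (see the commented-out term above).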
|
|
            check_grad = False
            if check_grad:
                print(loss.requires_grad)
                print(loss_seg.requires_grad)
                print(loss_compactness.requires_grad)
                print(loss_MI_RVpenalty.requires_grad)
                print(KL_loss.requires_grad)
                print(loss_MI_size.requires_grad)
                print(loss_geo.requires_grad)
                print(loss_signal.requires_grad)

            visual_check = False
            if visual_check:
                y_predict = y_MI[0].cpu().detach().numpy()
                y_gd = gt_MI[0].cpu().detach().numpy()
                x_input = partial_input[0].cpu().detach().numpy()
                y_predict_argmax = np.argmax(y_predict, axis=0)
                visualize_PC_with_twolabel(x_input[0:3, 0:args.num_input].transpose(), y_predict_argmax, y_gd, filename='RNmap_gd_pre.jpg')

            loss.backward()
            optimizer.step()

            f_train.write(str(loss.item()) + '\n')
            f_MI_train.write(str(loss_seg.item()) + '\n')
            f_compactness_train.write(str(loss_compactness.item()) + '\n')
            f_KL_train.write(str(KL_loss.item()) + '\n')
            f_PC_train.write(str(loss_geo.item()) + '\n')
            f_ecg_train.write(str(loss_signal.item()) + '\n')
            f_size_train.write(str(loss_MI_size.item()) + '\n')
            f_RVp_train.write(str(loss_MI_RVpenalty.item()) + '\n')

            iter_count += 1
            total_loss += loss.item()

            if i % 50 == 0:
                print("Training epoch {}/{}, iteration {}/{}: loss is {}".format(epoch, args.epochs, i, max_iter, loss.item()))
        scheduler.step()  # decay the learning rate once per epoch

        print("\033[96mTraining epoch {}/{}: avg loss = {}\033[0m".format(epoch, args.epochs, total_loss / iter_count))
|
|

        # evaluation
        network.eval()
        with torch.no_grad():
            total_loss, iter_count = 0, 0
            for i, data in enumerate(val_dataloader, 1):
                partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type = data
                partial_input, ECG_input, gt_MI = partial_input.to(DEVICE), ECG_input.to(DEVICE), gt_MI.to(DEVICE)
                partial_input_coarse = partial_input_coarse.to(DEVICE)
                partial_input = partial_input.permute(0, 2, 1)

                y_MI, y_coarse, y_detail, y_ECG, mu, log_var = network(partial_input[:, 0:7, :], ECG_input)

                loss_seg, loss_compactness, loss_MI_RVpenalty, loss_MI_size, KL_loss = calculate_inference_loss(y_MI, gt_MI, mu, log_var, partial_input)
                loss_geo, loss_signal = calculate_reconstruction_loss(y_coarse, y_detail, partial_input_coarse, partial_input, y_ECG, ECG_input)
                loss = loss_seg + args.lamda_compact*loss_compactness + args.lamda_RVp*loss_MI_RVpenalty + args.lamda_MIsize*loss_MI_size + args.lamda_KL*KL_loss + args.lamda_recon*loss_geo  # + args.lamda_recon*loss_signal

                total_loss += loss.item()
                iter_count += 1

                # every 25 epochs, visualise the prediction for the first validation batch
                if ((epoch % 25) == 0) and (i == 1):
                    y_predict = y_MI[0].cpu().detach().numpy()
                    y_gd = gt_MI[0].cpu().detach().numpy()
                    x_input = partial_input[0].cpu().detach().numpy()
                    y_predict_argmax = np.argmax(y_predict, axis=0)
                    visualize_PC_with_twolabel(x_input[0:3, 0:args.num_input].transpose(), y_predict_argmax, y_gd, filename='RNmap_gd_pre.jpg')

                f_val.write(str(loss.item()) + '\n')
                f_MI_val.write(str(loss_seg.item()) + '\n')
                f_compactness_val.write(str(loss_compactness.item()) + '\n')
                f_KL_val.write(str(KL_loss.item()) + '\n')
                f_PC_val.write(str(loss_geo.item()) + '\n')
                f_ecg_val.write(str(loss_signal.item()) + '\n')
                f_size_val.write(str(loss_MI_size.item()) + '\n')
                f_RVp_val.write(str(loss_MI_RVpenalty.item()) + '\n')

        mean_loss = total_loss / iter_count
        print("\033[35mValidation epoch {}/{}, loss is {}\033[0m".format(epoch, args.epochs, mean_loss))

        # record the best model and epoch
        if mean_loss < minimum_loss:
            best_epoch = epoch
            minimum_loss = mean_loss
            strNetSaveName = 'net_model.pkl'
            # strNetSaveName = 'net_with_%d.pkl' % epoch
            torch.save(network.state_dict(), args.log_dir + '/' + strNetSaveName)

        print("\033[4;37mBest model (lowest loss) in epoch {}\033[0m".format(best_epoch))

        # close the per-epoch log handles so file descriptors do not leak
        for f in (f_train, f_val, f_MI_train, f_MI_val, f_compactness_train, f_compactness_val, f_KL_train, f_KL_val, f_PC_train, f_PC_val, f_ecg_train, f_ecg_val, f_size_train, f_size_val, f_RVp_train, f_RVp_val):
            f.close()

    lossplot(lossfile_train, lossfile_val)
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--partial_root', type=str, default='./Big_data_inference/meta_data/UKB_clinical_data/')
    parser.add_argument('--model', type=str, default=None)  # e.g. 'log/net_model.pkl' to resume from a checkpoint
    parser.add_argument('--in_ch', type=int, default=3+4)  # coordinate dimensions + label channels
    parser.add_argument('--out_ch', type=int, default=3)  # scar, border zone (BZ), normal; 18 for ECG-based classification
    parser.add_argument('--z_dims', type=int, default=16)
    parser.add_argument('--num_input', type=int, default=1024*4)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lamda_recon', type=float, default=1)
    parser.add_argument('--lamda_KL', type=float, default=1e-2)
    parser.add_argument('--lamda_MIsize', type=float, default=1)
    parser.add_argument('--lamda_RVp', type=float, default=1)
    parser.add_argument('--lamda_compact', type=float, default=1)
    parser.add_argument('--base_lr', type=float, default=1e-4)
    parser.add_argument('--lr_decay_steps', type=int, default=50)
    parser.add_argument('--lr_decay_rate', type=float, default=0.5)
    parser.add_argument('--weight_decay', type=float, default=1e-3)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--num_workers', type=int, default=1)
    parser.add_argument('--log_dir', type=str, default='log')
    args = parser.parse_args()

    train(args)
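
# Example invocation with the defaults above (the dataset path is a placeholder
# for the expected layout; adjust to your setup). Calling train_ecg(args)
# instead trains the classification-only variant:
#   python train.py --partial_root ./Big_data_inference/meta_data/UKB_clinical_data/ \
#       --batch_size 4 --epochs 500 --log_dir log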