|
a |
|
b/train.py |
|
|
1 |
import sys |
|
|
2 |
sys.path.append('architectures/deeplab_3D/') |
|
|
3 |
sys.path.append('architectures/unet_3D/') |
|
|
4 |
sys.path.append('architectures/hrnet_3D/') |
|
|
5 |
sys.path.append('architectures/experiment_nets_3D/') |
|
|
6 |
sys.path.append('utils/') |
|
|
7 |
|
|
|
8 |
import torch |
|
|
9 |
import torch.nn as nn |
|
|
10 |
from torch.autograd import Variable |
|
|
11 |
import torch.backends.cudnn as cudnn |
|
|
12 |
import torch.nn.functional as F |
|
|
13 |
import torch.optim as optim |
|
|
14 |
|
|
|
15 |
import numpy as np |
|
|
16 |
import scipy.misc |
|
|
17 |
import os |
|
|
18 |
from tqdm import * |
|
|
19 |
import random |
|
|
20 |
from random import randint |
|
|
21 |
from docopt import docopt |
|
|
22 |
|
|
|
23 |
import deeplab_resnet_3D |
|
|
24 |
import unet_3D |
|
|
25 |
import highresnet_3D |
|
|
26 |
import exp_net_3D |
|
|
27 |
|
|
|
28 |
import lossF |
|
|
29 |
import PP |
|
|
30 |
import augmentations as AUG |
|
|
31 |
|
|
|
32 |
import nibabel as nib |
|
|
33 |
import evalF as EF |
|
|
34 |
import evalFP as EFP |
|
|
35 |
|
|
|
36 |
|
|
|
37 |
docstr = """Write something here |
|
|
38 |
|
|
|
39 |
Usage: |
|
|
40 |
train.py [options] |
|
|
41 |
|
|
|
42 |
Options: |
|
|
43 |
-h, --help Print this message |
|
|
44 |
--archId=<int> Architecture to run, 0 is DeepLab 3D, 1 is U-net3D, 2 is HRNet [default: 2] |
|
|
45 |
--trainMethod=<int> 0 is full image, 1 is by patches (random), 2 is by patches (center pixel) [default: 1] |
|
|
46 |
--lossFunction=<str> Loss function name. 'dice' option available [default: dice] |
|
|
47 |
--imgSize=<str> Image size [default: 200x200x100] |
|
|
48 |
--mainFolderPath=<str> Main folder path [default: ../Data/MS2017b/] |
|
|
49 |
--patchSize=<int> Size of the patch [default: 60] |
|
|
50 |
--patchSizeStage0=<int> Size of the patch at stage 0 [default: 41] |
|
|
51 |
--namePostfix=<str> Postfix of flair. i.e. to use FLAIR_s postfix is _s. This also determines the train file [default: _200x200x100orig] |
|
|
52 |
--modelPath=<str> Path of model to continue training on [default: none] |
|
|
53 |
--NoLabels=<int> The number of different labels in training data, including background [default: 2] |
|
|
54 |
--maxIter=<int> Maximum number of iterations [default: 20000] |
|
|
55 |
--maxIterStage0=<int> Maximum number of iterations for stage 0 training [default: -1] |
|
|
56 |
-i, --iterSize=<int> Num iters to accumulate gradients over [default: 1] |
|
|
57 |
--lr=<float> Learning Rate [default: 0.0001] |
|
|
58 |
--gpu0=<int> GPU number [default: 0] |
|
|
59 |
--useGPU=<int> Use GPU or not [default: 0] |
|
|
60 |
--experiment=<str> Specify experiment instead to run. e.g. 1x1x1x1x1x1_1_0 means 1 dilations all 6 blocks, with priv, no ASPP [default: None] |
|
|
61 |
""" |
|
|
62 |
args = docopt(docstr, version='v0.1') |
|
|
63 |
print(args) |
|
|
64 |
|
|
|
65 |
arch_id = int(args['--archId']) |
|
|
66 |
train_method = int(args['--trainMethod']) |
|
|
67 |
loss_name = args['--lossFunction'] |
|
|
68 |
img_dims = np.array(args['--imgSize'].split('x'), dtype=np.int64) |
|
|
69 |
main_folder_path = args['--mainFolderPath'] |
|
|
70 |
patch_size = int(args['--patchSize']) |
|
|
71 |
|
|
|
72 |
postfix = args['--namePostfix'] |
|
|
73 |
model_path = args['--modelPath'] |
|
|
74 |
num_labels = int(args['--NoLabels']) |
|
|
75 |
max_iter = int(args['--maxIter']) |
|
|
76 |
|
|
|
77 |
iter_size = int(args['--iterSize']) |
|
|
78 |
base_lr = float(args['--lr']) |
|
|
79 |
experiment = str(args['--experiment']) |
|
|
80 |
gpu0 = int(args['--gpu0']) |
|
|
81 |
useGPU = int(args['--useGPU']) |
|
|
82 |
batch_size = 1 |
|
|
83 |
#img_dims = [197, 233, 189] |
|
|
84 |
list_path = main_folder_path + 'train' + postfix + '.txt' |
|
|
85 |
print('READING from ', list_path) |
|
|
86 |
img_type_path = 'pre/FLAIR' + postfix + '.nii.gz' |
|
|
87 |
gt_type_path = 'wmh' + postfix + '.nii.gz' |
|
|
88 |
|
|
|
89 |
|
|
|
90 |
patch_size_stage0 = int(args['--patchSizeStage0']) |
|
|
91 |
max_iter_stage0 = int(args['--maxIterStage0']) |
|
|
92 |
|
|
|
93 |
iter_low = 1 |
|
|
94 |
iter_high = max_iter + 1 |
|
|
95 |
|
|
|
96 |
if model_path != 'none': |
|
|
97 |
iter_low = int(model_path.split('iter_')[-1].replace('.pth','')) + 1 |
|
|
98 |
if iter_low >= iter_high: |
|
|
99 |
print('Model already at ' + str(iter_low) + ' iterations. Change max iter size') |
|
|
100 |
sys.exit() |
|
|
101 |
|
|
|
102 |
num_labels2 = 209 |
|
|
103 |
#change to 0 to enable stage 0 patch learning |
|
|
104 |
|
|
|
105 |
if num_labels == 2: |
|
|
106 |
onlyLesions = True |
|
|
107 |
else: |
|
|
108 |
onlyLesions = False |
|
|
109 |
|
|
|
110 |
if useGPU: |
|
|
111 |
cudnn.enabled = True |
|
|
112 |
else: |
|
|
113 |
cudnn.enabled = False |
|
|
114 |
|
|
|
115 |
if experiment != 'None': |
|
|
116 |
snapshot_prefix = 'EXP3D' + '_' + experiment + '_' + loss_name + '_' + str(train_method) |
|
|
117 |
else: |
|
|
118 |
if arch_id == 0: |
|
|
119 |
snapshot_prefix = 'DL3D_' + loss_name + '_' + str(train_method) + '_' + PP.getTime() |
|
|
120 |
elif arch_id == 1: |
|
|
121 |
snapshot_prefix = 'UNET3D_' + loss_name + '_' + str(train_method) + '_' + PP.getTime() |
|
|
122 |
elif arch_id == 2: |
|
|
123 |
snapshot_prefix = 'HR3D' + loss_name + '_' + str(train_method) + '_' + PP.getTime() |
|
|
124 |
to_center_pixel = False |
|
|
125 |
center_pixel_folder_path, locs_lesion, locs_other = (None, None, None) |
|
|
126 |
if train_method == 2: |
|
|
127 |
to_center_pixel = True |
|
|
128 |
if not os.path.exists(os.path.join(main_folder_path, 'centerPixelPatches' + postfix + '_' + str(patch_size))): |
|
|
129 |
print('Pixel patch folder does not exist') |
|
|
130 |
sys.exit() |
|
|
131 |
#load few files |
|
|
132 |
img_list = PP.read_file(list_path) |
|
|
133 |
|
|
|
134 |
results_folder = 'train_results/' |
|
|
135 |
log_file_path = os.path.join(results_folder, 'logs', snapshot_prefix + '_log.txt') |
|
|
136 |
model_file_path = os.path.join(results_folder, 'models', snapshot_prefix + '_best.pth') |
|
|
137 |
|
|
|
138 |
logfile = open(log_file_path, 'w+') |
|
|
139 |
info_run = "arch ID: {:d} | max iters: {:10d} | max iters stage 0 : {:10d} | train method : {} | lr : {}".format(arch_id, max_iter, max_iter_stage0, train_method, base_lr) |
|
|
140 |
logfile.write(info_run + '\n') |
|
|
141 |
logfile.flush() |
|
|
142 |
|
|
|
143 |
def lr_poly(base_lr, iter,max_iter,power): |
|
|
144 |
return base_lr*((1-float(iter)/max_iter)**(power)) |
|
|
145 |
|
|
|
146 |
def modelInit(): |
|
|
147 |
isPriv = False |
|
|
148 |
if arch_id > 10: |
|
|
149 |
isPriv = True |
|
|
150 |
|
|
|
151 |
if experiment != 'None': |
|
|
152 |
dilation_arr, isPriv, withASPP = PP.getExperimentInfo(experiment) |
|
|
153 |
model = exp_net_3D.getExpNet(num_labels, dilation_arr, isPriv, NoLabels2 = num_labels2, withASPP = withASPP) |
|
|
154 |
elif arch_id == 0: |
|
|
155 |
model = deeplab_resnet_3D.Res_Deeplab(num_labels) |
|
|
156 |
elif arch_id == 1: |
|
|
157 |
model = unet_3D.UNet3D(1, num_labels) |
|
|
158 |
elif arch_id == 2: |
|
|
159 |
model = highresnet_3D.getHRNet(num_labels) |
|
|
160 |
|
|
|
161 |
if model_path != 'none': |
|
|
162 |
if useGPU: |
|
|
163 |
#loading on GPU when model was saved on GPU |
|
|
164 |
saved_state_dict = torch.load(model_path) |
|
|
165 |
else: |
|
|
166 |
#loading on CPU when model was saved on GPU |
|
|
167 |
saved_state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) |
|
|
168 |
model.load_state_dict(saved_state_dict) |
|
|
169 |
|
|
|
170 |
model.float() |
|
|
171 |
model.eval() # use_global_stats = True |
|
|
172 |
return model, isPriv |
|
|
173 |
|
|
|
174 |
def trainModel(model): |
|
|
175 |
if useGPU: |
|
|
176 |
model.cuda(gpu0) |
|
|
177 |
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = base_lr) |
|
|
178 |
|
|
|
179 |
optimizer.zero_grad() |
|
|
180 |
print(model) |
|
|
181 |
curr_val = 0 |
|
|
182 |
best_val = 0 |
|
|
183 |
val_change = False |
|
|
184 |
loss_arr = np.zeros([iter_size]) |
|
|
185 |
loss_arr_i = 0 |
|
|
186 |
stage = 0 |
|
|
187 |
print('---------------') |
|
|
188 |
print('STAGE ' + str(stage)) |
|
|
189 |
print('---------------') |
|
|
190 |
|
|
|
191 |
for iter in range(iter_low, iter_high): |
|
|
192 |
if iter > max_iter_stage0 and stage != 1: |
|
|
193 |
print('---------------') |
|
|
194 |
print('Stage 1') |
|
|
195 |
print('---------------') |
|
|
196 |
stage = 1 |
|
|
197 |
|
|
|
198 |
if train_method == 0: |
|
|
199 |
img_b, label_b, _ = PP.extractImgBatch(batch_size, img_list, img_dims, onlyLesions, |
|
|
200 |
main_folder_path = '../Data/MS2017b/') |
|
|
201 |
elif train_method == 1 or train_method == 2: |
|
|
202 |
if stage == 0: |
|
|
203 |
batch_size = 1 |
|
|
204 |
img_b, label_b, _ = PP.extractPatchBatch(batch_size, patch_size_stage0, img_list, onlyLesions, center_pixel = to_center_pixel, main_folder_path = '../Data/MS2017b/', postfix=postfix) |
|
|
205 |
else: |
|
|
206 |
batch_size = 1 |
|
|
207 |
img_b, label_b, _ = PP.extractPatchBatch(batch_size, patch_size, img_list, onlyLesions, center_pixel = to_center_pixel, main_folder_path = '../Data/MS2017b/', postfix=postfix) |
|
|
208 |
else: |
|
|
209 |
print('Invalid training method format') |
|
|
210 |
sys.exit() |
|
|
211 |
|
|
|
212 |
if stage == 0: |
|
|
213 |
img_b, label_b = AUG.augmentPatchLossLess([img_b, label_b]) |
|
|
214 |
img_b, label_b = AUG.augmentPatchLossy([img_b, label_b]) |
|
|
215 |
#img_b, label_b = AUG.augmentPatchLossless(img_b, label_b) |
|
|
216 |
#img_b is of shape (batch_num) x 1 x dim1 x dim2 x dim3 |
|
|
217 |
#label_b is of shape (batch_num) x 1 x dim1 x dim2 x dim3 |
|
|
218 |
#batch_num should be 1 since too memory intensive |
|
|
219 |
|
|
|
220 |
label_b = label_b.astype(np.int64) |
|
|
221 |
#convert label from (batch_num x 1 x dim1 x dim2 x dim3) |
|
|
222 |
# to ((batch_numxdim1*dim2*dim3) x 3) (one hot) |
|
|
223 |
temp = label_b.reshape([-1]) |
|
|
224 |
label_b = np.zeros([temp.size, num_labels]) |
|
|
225 |
label_b[np.arange(temp.size),temp] = 1 |
|
|
226 |
label_b = torch.from_numpy(label_b).float() |
|
|
227 |
|
|
|
228 |
imgs = torch.from_numpy(img_b).float() |
|
|
229 |
|
|
|
230 |
if useGPU: |
|
|
231 |
imgs, label_b = Variable(imgs).cuda(gpu0), Variable(label_b).cuda(gpu0) |
|
|
232 |
else: |
|
|
233 |
imgs, label_b = Variable(imgs), Variable(label_b) |
|
|
234 |
|
|
|
235 |
#--------------------------------------------- |
|
|
236 |
#out size is (1, 3, dim1, dim2, dim3) |
|
|
237 |
#--------------------------------------------- |
|
|
238 |
out = model(imgs) |
|
|
239 |
out = out.permute(0,2,3,4,1).contiguous() |
|
|
240 |
out = out.view(-1, num_labels) |
|
|
241 |
#--------------------------------------------- |
|
|
242 |
#out size is (1 * dim1 * dim2 * dim3, 3) |
|
|
243 |
#--------------------------------------------- |
|
|
244 |
|
|
|
245 |
#loss function |
|
|
246 |
m = nn.Softmax() |
|
|
247 |
loss = lossF.simple_dice_loss3D(m(out), label_b) |
|
|
248 |
|
|
|
249 |
loss /= iter_size |
|
|
250 |
loss.backward() |
|
|
251 |
|
|
|
252 |
loss_val = loss.data.cpu().numpy() |
|
|
253 |
loss_arr[loss_arr_i] = loss_val |
|
|
254 |
loss_arr_i = (loss_arr_i + 1) % iter_size |
|
|
255 |
|
|
|
256 |
if iter % 1 == 0: |
|
|
257 |
if val_change: |
|
|
258 |
print "iter = {:6d}/{:6d} Loss: {:1.6f} Val Score: {:1.6f} \r".format(iter-1, max_iter, float(loss_val)*iter_size, curr_val), |
|
|
259 |
sys.stdout.flush() |
|
|
260 |
print "" |
|
|
261 |
val_change = False |
|
|
262 |
print "iter = {:6d}/{:6d} Loss: {:1.6f} Val Score: {:1.6f} \r".format(iter, max_iter, float(loss_val)*iter_size, curr_val), |
|
|
263 |
sys.stdout.flush() |
|
|
264 |
if iter % 1000 == 0: |
|
|
265 |
val_change = True |
|
|
266 |
curr_val = EF.evalModelX(model, num_labels, postfix, main_folder_path, (train_method != 0), gpu0, useGPU, eval_metric = 'iou', patch_size = patch_size, extra_patch = 5) |
|
|
267 |
if curr_val > best_val: |
|
|
268 |
best_val = curr_val |
|
|
269 |
print('\nSaving better model...') |
|
|
270 |
torch.save(model.state_dict(), model_file_path) |
|
|
271 |
logfile.write("iter = {:6d}/{:6d} Loss: {:1.6f} Val Score: {:1.6f} \n".format(iter, max_iter, np.sum(loss_arr), curr_val)) |
|
|
272 |
logfile.flush() |
|
|
273 |
if iter % iter_size == 0: |
|
|
274 |
optimizer.step() |
|
|
275 |
optimizer.zero_grad() |
|
|
276 |
|
|
|
277 |
del out, loss |
|
|
278 |
|
|
|
279 |
def setupGIFVar(gif_b): |
|
|
280 |
gif_b = gif_b.astype(np.int64) |
|
|
281 |
gif_b = gif_b.reshape([-1]) |
|
|
282 |
gif_b = torch.from_numpy(gif_b).long() |
|
|
283 |
|
|
|
284 |
if useGPU: |
|
|
285 |
gif_b = Variable(gif_b).cuda(gpu0) |
|
|
286 |
else: |
|
|
287 |
gif_b = Variable(gif_b) |
|
|
288 |
return gif_b |
|
|
289 |
|
|
|
290 |
def trainModelPriv(model): |
|
|
291 |
if useGPU: |
|
|
292 |
model.cuda(gpu0) |
|
|
293 |
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = base_lr) |
|
|
294 |
optimizer.zero_grad() |
|
|
295 |
print(model) |
|
|
296 |
curr_val1 = 0 |
|
|
297 |
curr_val2 = 0 |
|
|
298 |
best_val2 = 0 |
|
|
299 |
val_change = False |
|
|
300 |
loss_arr1 = np.zeros([iter_size]) |
|
|
301 |
loss_arr2 = np.zeros([iter_size]) |
|
|
302 |
loss_arr_i = 0 |
|
|
303 |
|
|
|
304 |
stage = 0 |
|
|
305 |
print('---------------') |
|
|
306 |
print('STAGE ' + str(stage)) |
|
|
307 |
print('---------------') |
|
|
308 |
|
|
|
309 |
for iter in range(iter_low, iter_high): |
|
|
310 |
if iter > max_iter_stage0 and stage != 1: |
|
|
311 |
print('---------------') |
|
|
312 |
print('Stage 1') |
|
|
313 |
print('---------------') |
|
|
314 |
stage = 1 |
|
|
315 |
|
|
|
316 |
if train_method == 0: |
|
|
317 |
img_b, label_b, gif_b = PP.extractImgBatch(batch_size, img_list, img_dims, onlyLesions, |
|
|
318 |
main_folder_path = '../Data/MS2017b/', with_priv = True) |
|
|
319 |
elif train_method == 1 or train_method == 2: |
|
|
320 |
if stage == 0: |
|
|
321 |
batch_size = 5 |
|
|
322 |
img_b, label_b, gif_b = PP.extractPatchBatch(batch_size, patch_size_stage0, img_list, onlyLesions, |
|
|
323 |
center_pixel = to_center_pixel, |
|
|
324 |
main_folder_path = '../Data/MS2017b/', |
|
|
325 |
postfix=postfix, with_priv= True) |
|
|
326 |
else: |
|
|
327 |
batch_size = 1 |
|
|
328 |
img_b, label_b, gif_b = PP.extractPatchBatch(batch_size, patch_size, img_list, onlyLesions, |
|
|
329 |
center_pixel = to_center_pixel, |
|
|
330 |
main_folder_path = '../Data/MS2017b/', |
|
|
331 |
postfix=postfix, with_priv= True) |
|
|
332 |
else: |
|
|
333 |
print('Invalid training method format') |
|
|
334 |
sys.exit() |
|
|
335 |
|
|
|
336 |
img_b, label_b, gif_b = AUG.augmentPatchLossy([img_b, label_b, gif_b]) |
|
|
337 |
|
|
|
338 |
#img_b is of shape (batch_num) x 1 x dim1 x dim2 x dim3 |
|
|
339 |
#label_b is of shape (batch_num) x 1 x dim1 x dim2 x dim3 |
|
|
340 |
|
|
|
341 |
label_b = label_b.astype(np.int64) |
|
|
342 |
|
|
|
343 |
#convert label from (batch_num x 1 x dim1 x dim2 x dim3) |
|
|
344 |
# to ((batch_numxdim1*dim2*dim3) x 3) (one hot) |
|
|
345 |
temp = label_b.reshape([-1]) |
|
|
346 |
label_b = np.zeros([temp.size, num_labels]) |
|
|
347 |
label_b[np.arange(temp.size),temp] = 1 |
|
|
348 |
label_b = torch.from_numpy(label_b).float() |
|
|
349 |
|
|
|
350 |
imgs = torch.from_numpy(img_b).float() |
|
|
351 |
|
|
|
352 |
if useGPU: |
|
|
353 |
imgs, label_b = Variable(imgs).cuda(gpu0), Variable(label_b).cuda(gpu0) |
|
|
354 |
else: |
|
|
355 |
imgs, label_b = Variable(imgs), Variable(label_b) |
|
|
356 |
|
|
|
357 |
gif_b = setupGIFVar(gif_b) |
|
|
358 |
|
|
|
359 |
#--------------------------------------------- |
|
|
360 |
#out size is (1, 3, dim1, dim2, dim3) |
|
|
361 |
#--------------------------------------------- |
|
|
362 |
#out1 is extra info |
|
|
363 |
out1, out2 = model(imgs) |
|
|
364 |
|
|
|
365 |
out1 = out1.permute(0,2,3,4,1).contiguous() |
|
|
366 |
out1 = out1.view(-1, num_labels2) |
|
|
367 |
|
|
|
368 |
out2 = out2.permute(0,2,3,4,1).contiguous() |
|
|
369 |
out2 = out2.view(-1, num_labels) |
|
|
370 |
#--------------------------------------------- |
|
|
371 |
#out size is (1 * dim1 * dim2 * dim3, 3) |
|
|
372 |
#--------------------------------------------- |
|
|
373 |
m2 = nn.Softmax() |
|
|
374 |
loss2 = lossF.simple_dice_loss3D(m2(out2), label_b) |
|
|
375 |
m1 = nn.LogSoftmax() |
|
|
376 |
loss1 = F.nll_loss(m1(out1), gif_b) |
|
|
377 |
|
|
|
378 |
loss1 /= iter_size |
|
|
379 |
loss2 /= iter_size |
|
|
380 |
|
|
|
381 |
torch.autograd.backward([loss1, loss2]) |
|
|
382 |
|
|
|
383 |
loss_val1 = float(loss1.data.cpu().numpy()) |
|
|
384 |
loss_arr1[loss_arr_i] = loss_val1 |
|
|
385 |
|
|
|
386 |
loss_val2 = float(loss2.data.cpu().numpy()) |
|
|
387 |
loss_arr2[loss_arr_i] = loss_val2 |
|
|
388 |
|
|
|
389 |
loss_arr_i = (loss_arr_i + 1) % iter_size |
|
|
390 |
|
|
|
391 |
if iter % 1 == 0: |
|
|
392 |
if val_change: |
|
|
393 |
print "iter = {:6d}/{:6d} Loss_main: {:1.6f} Loss_secondary: {:1.6f} Val Score: {:1.6f} Val Score secondary: {:1.6f} \r".format(iter-1, max_iter, loss_val2*iter_size, loss_val1*iter_size, curr_val2, curr_val1), |
|
|
394 |
sys.stdout.flush() |
|
|
395 |
print "" |
|
|
396 |
val_change = False |
|
|
397 |
print "iter = {:6d}/{:6d} Loss_main: {:1.6f} Loss_secondary: {:1.6f} Val Score main: {:1.6f} Val Score secondary: {:1.6f} \r".format(iter, max_iter, loss_val2*iter_size, loss_val1*iter_size, curr_val2, curr_val1), |
|
|
398 |
sys.stdout.flush() |
|
|
399 |
if iter % 2000 == 0: |
|
|
400 |
val_change = True |
|
|
401 |
curr_val1, curr_val2 = EFP.evalModelX(model, num_labels, num_labels2, postfix, main_folder_path, (train_method != 0), gpu0, useGPU, eval_metric = 'iou', patch_size = patch_size, extra_patch = 5, priv_eval = True) |
|
|
402 |
if curr_val2 > best_val2: |
|
|
403 |
best_val2 = curr_val2 |
|
|
404 |
torch.save(model.state_dict(), model_file_path) |
|
|
405 |
print('\nSaving better model...') |
|
|
406 |
logfile.write("iter = {:6d}/{:6d} Loss_main: {:1.6f} Loss_secondary: {:1.6f} Val Score main: {:1.6f} Val Score secondary: {:1.6f} \n".format(iter, max_iter, np.sum(loss_arr2), np.sum(loss_arr1), curr_val2, curr_val1)) |
|
|
407 |
logfile.flush() |
|
|
408 |
if iter % iter_size == 0: |
|
|
409 |
optimizer.step() |
|
|
410 |
optimizer.zero_grad() |
|
|
411 |
|
|
|
412 |
del out1, out2, loss1, loss2 |
|
|
413 |
|
|
|
414 |
if __name__ == "__main__": |
|
|
415 |
model, with_priv = modelInit() |
|
|
416 |
if with_priv: |
|
|
417 |
trainModelPriv(model) |
|
|
418 |
else: |
|
|
419 |
trainModel(model) |
|
|
420 |
logfile.close() |