neusomatic/python/train.py

#-------------------------------------------------------------------------
# train.py
# Train NeuSomatic network
#-------------------------------------------------------------------------

import os
import traceback
import argparse
import datetime
import logging

import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import torchvision
from random import shuffle
import pickle

from network import NeuSomaticNet
from dataloader import NeuSomaticDataset, matrix_transform
from merge_tsvs import merge_tsvs

type_class_dict = {"DEL": 0, "INS": 1, "NONE": 2, "SNP": 3}
vartype_classes = ['DEL', 'INS', 'NONE', 'SNP']

import torch._utils
try:
    torch._utils._rebuild_tensor_v2
except AttributeError:
    def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
        tensor = torch._utils._rebuild_tensor(
            storage, storage_offset, size, stride)
        tensor.requires_grad = requires_grad
        tensor._backward_hooks = backward_hooks
        return tensor
    torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2
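# Note: the try/except above is a backward-compatibility shim. It defines
# torch._utils._rebuild_tensor_v2 when the installed PyTorch lacks it, so that
# checkpoints serialized by newer PyTorch releases can still be deserialized.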
|
|
def make_weights_for_balanced_classes(count_class_t, count_class_l, nclasses_t, nclasses_l,
                                      none_count=None):
    logger = logging.getLogger(make_weights_for_balanced_classes.__name__)

    w_t = [0] * nclasses_t
    w_l = [0] * nclasses_l

    count_class_t = list(count_class_t)
    count_class_l = list(count_class_l)
    if none_count:
        count_class_t[type_class_dict["NONE"]] = none_count
        count_class_l[0] = none_count

    logger.info("count type classes: {}".format(
        list(zip(vartype_classes, count_class_t))))
    N = float(sum(count_class_t))
    for i in range(nclasses_t):
        w_t[i] = (1 - (float(count_class_t[i]) / float(N))) / float(nclasses_t)
    w_t = np.array(w_t)
    logger.info("weight type classes: {}".format(
        list(zip(vartype_classes, w_t))))

    logger.info("count length classes: {}".format(list(
        zip(range(nclasses_l), count_class_l))))
    N = float(sum(count_class_l))
    for i in range(nclasses_l):
        w_l[i] = (1 - (float(count_class_l[i]) / float(N))) / float(nclasses_l)
    w_l = np.array(w_l)
    logger.info("weight length classes: {}".format(list(
        zip(range(nclasses_l), w_l))))
    return w_t, w_l
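# Illustrative example (not part of the original script): for type counts
# [DEL=100, INS=50, NONE=800, SNP=50], N = 1000 and the formula
# w_t[i] = (1 - count_i / N) / nclasses_t yields
# w_t = [0.225, 0.2375, 0.05, 0.2375], so rarer classes get larger weights.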
|
|
def test(net, epoch, validation_loader, use_cuda):
    logger = logging.getLogger(test.__name__)
    net.eval()
    nclasses = len(vartype_classes)
    class_correct = list(0. for i in range(nclasses))
    class_total = list(0. for i in range(nclasses))
    class_p_total = list(0. for i in range(nclasses))

    len_class_correct = list(0. for i in range(4))
    len_class_total = list(0. for i in range(4))
    len_class_p_total = list(0. for i in range(4))

    falses = []
    for data in validation_loader:
        (matrices, labels, _, var_len_s, _), (paths) = data

        matrices = Variable(matrices)
        if use_cuda:
            matrices = matrices.cuda()

        outputs, _ = net(matrices)
        [outputs1, outputs2, outputs3] = outputs

        _, predicted = torch.max(outputs1.data.cpu(), 1)
        pos_pred = outputs2.data.cpu().numpy()
        _, len_pred = torch.max(outputs3.data.cpu(), 1)
        preds = {}
        for i, _ in enumerate(paths[0]):
            preds[i] = [vartype_classes[predicted[i]], pos_pred[i], len_pred[i]]

        if labels.size()[0] > 1:
            compare_labels = (predicted == labels).squeeze()
        else:
            compare_labels = (predicted == labels)
        false_preds = np.where(compare_labels.numpy() == 0)[0]
        if len(false_preds) > 0:
            for i in false_preds:
                falses.append([paths[0][i], vartype_classes[predicted[i]], pos_pred[i], len_pred[i],
                               list(
                                   np.round(F.softmax(outputs1[i, :], 0).data.cpu().numpy(), 4)),
                               list(
                                   np.round(F.softmax(outputs3[i, :], 0).data.cpu().numpy(), 4))])

        for i in range(len(labels)):
            label = labels[i]
            class_correct[label] += compare_labels[i].data.cpu().numpy()
            class_total[label] += 1
        for i in range(len(predicted)):
            label = predicted[i]
            class_p_total[label] += 1

        if var_len_s.size()[0] > 1:
            compare_len = (len_pred == var_len_s).squeeze()
        else:
            compare_len = (len_pred == var_len_s)

        for i in range(len(var_len_s)):
            len_ = var_len_s[i]
            len_class_correct[len_] += compare_len[i].data.cpu().numpy()
            len_class_total[len_] += 1
        for i in range(len(len_pred)):
            len_ = len_pred[i]
            len_class_p_total[len_] += 1

    for i in range(nclasses):
        SN = 100 * class_correct[i] / (class_total[i] + 0.0001)
        PR = 100 * class_correct[i] / (class_p_total[i] + 0.0001)
        F1 = 2 * PR * SN / (PR + SN + 0.0001)
        logger.info('Epoch {}: Type Accuracy of {:>5} ({}) : {:.2f} {:.2f} {:.2f}'.format(
            epoch,
            vartype_classes[i], class_total[i],
            SN, PR, F1))
    logger.info('Epoch {}: Type Accuracy of the network on the {} test candidates: {:.4f} %'.format(
        epoch, sum(class_total), (
            100 * sum(class_correct) / float(sum(class_total)))))

    for i in range(4):
        SN = 100 * len_class_correct[i] / (len_class_total[i] + 0.0001)
        PR = 100 * len_class_correct[i] / (len_class_p_total[i] + 0.0001)
        F1 = 2 * PR * SN / (PR + SN + 0.0001)
        logger.info('Epoch {}: Length Accuracy of {:>5} ({}) : {:.2f} {:.2f} {:.2f}'.format(
            epoch, i, len_class_total[i],
            SN, PR, F1))
    logger.info('Epoch {}: Length Accuracy of the network on the {} test candidates: {:.4f} %'.format(
        epoch, sum(len_class_total), (
            100 * sum(len_class_correct) / float(sum(len_class_total)))))

    net.train()


class SubsetNoneSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, none_indices, var_indices, none_count):
        self.none_indices = none_indices
        self.var_indices = var_indices
        self.none_count = none_count
        self.current_none_id = 0

    def __iter__(self):
        logger = logging.getLogger(SubsetNoneSampler.__iter__.__name__)
        if self.current_none_id > (len(self.none_indices) - self.none_count):
            this_round_nones = self.none_indices[self.current_none_id:]
            self.none_indices = list(map(lambda i: self.none_indices[i],
                                         torch.randperm(len(self.none_indices)).tolist()))
            self.current_none_id = self.none_count - len(this_round_nones)
            this_round_nones += self.none_indices[0:self.current_none_id]
        else:
            this_round_nones = self.none_indices[
                self.current_none_id:self.current_none_id + self.none_count]
            self.current_none_id += self.none_count

        current_indices = this_round_nones + self.var_indices
        ret = iter(map(lambda i: current_indices[i],
                       torch.randperm(len(current_indices))))
        return ret

    def __len__(self):
        return len(self.var_indices) + self.none_count
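# Note on SubsetNoneSampler: each epoch it yields all somatic (var) indices
# plus a rotating window of none_count non-somatic indices; when the NONE pool
# is exhausted it is reshuffled and the window wraps around. The combined index
# list is then randomly permuted, so every epoch sees a shuffled, rebalanced mix.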
|
|
def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpoint,
                     num_threads, batch_size, max_epochs, learning_rate, lr_drop_epochs,
                     lr_drop_ratio, momentum, boost_none, none_count_scale,
                     max_load_candidates, coverage_thr, save_freq, ensemble,
                     merged_candidates_per_tsv, merged_max_num_tsvs, overwrite_merged_tsvs,
                     train_split_len,
                     normalize_channels,
                     use_cuda):
    logger = logging.getLogger(train_neusomatic.__name__)

    logger.info("----------------Train NeuSomatic Network-------------------")
    logger.info("PyTorch Version: {}".format(torch.__version__))
    logger.info("Torchvision Version: {}".format(torchvision.__version__))

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    if not use_cuda:
        torch.set_num_threads(num_threads)

    data_transform = matrix_transform((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    num_channels = 119 if ensemble else 26
    net = NeuSomaticNet(num_channels)
    if use_cuda:
        logger.info("GPU training!")
        net.cuda()
    else:
        logger.info("CPU training!")

    if torch.cuda.device_count() > 1:
        logger.info("We use {} GPUs!".format(torch.cuda.device_count()))
        net = nn.DataParallel(net)

    if not os.path.exists("{}/models/".format(out_dir)):
        os.mkdir("{}/models/".format(out_dir))

    if checkpoint:
        logger.info(
            "Load pretrained model from checkpoint {}".format(checkpoint))
        pretrained_dict = torch.load(
            checkpoint, map_location=lambda storage, loc: storage)
        pretrained_state_dict = pretrained_dict["state_dict"]
        tag = pretrained_dict["tag"]
        sofar_epochs = pretrained_dict["epoch"]
        logger.info(
            "sofar_epochs from pretrained checkpoint: {}".format(sofar_epochs))
        coverage_thr = pretrained_dict["coverage_thr"]
        logger.info(
            "Override coverage_thr from pretrained checkpoint: {}".format(coverage_thr))
        if "normalize_channels" in pretrained_dict:
            normalize_channels = pretrained_dict["normalize_channels"]
        else:
            normalize_channels = False
        logger.info(
            "Override normalize_channels from pretrained checkpoint: {}".format(normalize_channels))
        prev_epochs = sofar_epochs + 1
        model_dict = net.state_dict()
        # 1. filter out unnecessary keys, fixing any "module." prefix mismatch
        #    between DataParallel and plain checkpoints
        # pretrained_state_dict = {
        #     k: v for k, v in pretrained_state_dict.items() if k in model_dict}
        if "module." in list(pretrained_state_dict.keys())[0] and "module." not in list(model_dict.keys())[0]:
            pretrained_state_dict = {k.split("module.")[1]: v for k, v in pretrained_state_dict.items(
            ) if k.split("module.")[1] in model_dict}
        elif "module." not in list(pretrained_state_dict.keys())[0] and "module." in list(model_dict.keys())[0]:
            pretrained_state_dict = {
                ("module." + k): v for k, v in pretrained_state_dict.items() if ("module." + k) in model_dict}
        else:
            pretrained_state_dict = {k: v for k,
                                     v in pretrained_state_dict.items() if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_state_dict)
        # 3. load the merged state dict (model_dict now holds the current
        #    weights updated with the matching pretrained entries)
        net.load_state_dict(model_dict)
    else:
        prev_epochs = 0
        time_now = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
        tag = "neusomatic_{}".format(time_now)
    logger.info("tag: {}".format(tag))

    shuffle(candidates_tsv)

    if len(candidates_tsv) > merged_max_num_tsvs:
        candidates_tsv = merge_tsvs(input_tsvs=candidates_tsv, out=out_dir,
                                    candidates_per_tsv=merged_candidates_per_tsv,
                                    max_num_tsvs=merged_max_num_tsvs,
                                    overwrite_merged_tsvs=overwrite_merged_tsvs,
                                    keep_none_types=True)

    Ls = []
    for tsv in candidates_tsv:
        idx = pickle.load(open(tsv + ".idx", "rb"))
        Ls.append(len(idx) - 1)

    Ls, candidates_tsv = list(zip(
        *sorted(zip(Ls, candidates_tsv), key=lambda x: x[0], reverse=True)))

    train_split_tsvs = []
    current_L = 0
    current_split_tsvs = []
    for i, (L, tsv) in enumerate(zip(Ls, candidates_tsv)):
        current_L += L
        current_split_tsvs.append(tsv)
        if current_L >= train_split_len or (i == len(candidates_tsv) - 1 and current_L > 0):
            logger.info("tsvs in split {}: {}".format(
                len(train_split_tsvs), current_split_tsvs))
            train_split_tsvs.append(current_split_tsvs)
            current_L = 0
            current_split_tsvs = []

    assert sum(map(lambda x: len(x), train_split_tsvs)) == len(candidates_tsv)
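    # The loop above greedily packs the tsvs (sorted largest-first) into splits
    # of roughly train_split_len candidates each; the assert checks that no tsv
    # was dropped or duplicated across splits.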
|
|
    train_sets = []
    none_counts = []
    var_counts = []
    none_indices_ = []
    var_indices_ = []
    samplers = []
    for split_i, tsvs in enumerate(train_split_tsvs):
        train_set = NeuSomaticDataset(roots=tsvs,
                                      max_load_candidates=int(
                                          max_load_candidates * len(tsvs) / float(len(candidates_tsv))),
                                      transform=data_transform, is_test=False,
                                      num_threads=num_threads, coverage_thr=coverage_thr,
                                      normalize_channels=normalize_channels)
        train_sets.append(train_set)
        none_indices = train_set.get_none_indices()
        var_indices = train_set.get_var_indices()
        if none_indices:
            none_indices = list(map(lambda i: none_indices[i],
                                    torch.randperm(len(none_indices)).tolist()))
        logger.info(
            "Non-somatic candidates in split {}: {}".format(split_i, len(none_indices)))
        if var_indices:
            var_indices = list(map(lambda i: var_indices[i],
                                   torch.randperm(len(var_indices)).tolist()))
        logger.info("Somatic candidates in split {}: {}".format(
            split_i, len(var_indices)))
        none_count = max(min(len(none_indices), len(
            var_indices) * none_count_scale), 1)
        logger.info(
            "Non-somatic considered in each epoch of split {}: {}".format(split_i, none_count))

        sampler = SubsetNoneSampler(none_indices, var_indices, none_count)
        samplers.append(sampler)
        none_counts.append(none_count)
        var_counts.append(len(var_indices))
        var_indices_.append(var_indices)
        none_indices_.append(none_indices)
    logger.info("# Total Train candidates: {}".format(
        sum(map(lambda x: len(x), train_sets))))
|
|
    if validation_candidates_tsv:
        validation_set = NeuSomaticDataset(roots=validation_candidates_tsv,
                                           max_load_candidates=max_load_candidates,
                                           transform=data_transform, is_test=True,
                                           num_threads=num_threads, coverage_thr=coverage_thr,
                                           normalize_channels=normalize_channels)
        validation_loader = torch.utils.data.DataLoader(validation_set,
                                                        batch_size=batch_size, shuffle=True,
                                                        num_workers=num_threads, pin_memory=True)
        logger.info("#Validation candidates: {}".format(len(validation_set)))

    count_class_t = [0] * 4
    count_class_l = [0] * 4
    for train_set in train_sets:
        for i in range(4):
            count_class_t[i] += train_set.count_class_t[i]
            count_class_l[i] += train_set.count_class_l[i]

    weights_type, weights_length = make_weights_for_balanced_classes(
        count_class_t, count_class_l, 4, 4, sum(none_counts))

    weights_type[2] *= boost_none
    weights_length[0] *= boost_none

    logger.info("weights_type:{}, weights_length:{}".format(
        weights_type, weights_length))

    loss_s = []
    gradients = torch.FloatTensor(weights_type)
    gradients2 = torch.FloatTensor(weights_length)
    if use_cuda:
        gradients = gradients.cuda()
        gradients2 = gradients2.cuda()
    criterion_crossentropy = nn.CrossEntropyLoss(gradients)
    criterion_crossentropy2 = nn.CrossEntropyLoss(gradients2)
    criterion_smoothl1 = nn.SmoothL1Loss()
    optimizer = optim.SGD(
        net.parameters(), lr=learning_rate, momentum=momentum)

    net.train()
    len_train_set = sum(none_counts) + sum(var_counts)
    logger.info("Number of candidates per epoch: {}".format(len_train_set))
    print_freq = max(1, int(len_train_set / float(batch_size) / 4.0))
    curr_epoch = prev_epochs
    torch.save({"state_dict": net.state_dict(),
                "tag": tag,
                "epoch": curr_epoch,
                "coverage_thr": coverage_thr,
                "normalize_channels": normalize_channels},
               '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch))

    if len(train_sets) == 1:
        train_sets[0].open_candidate_tsvs()
        train_loader = torch.utils.data.DataLoader(train_sets[0],
                                                   batch_size=batch_size,
                                                   num_workers=num_threads, pin_memory=True,
                                                   sampler=samplers[0])
    # loop over the dataset multiple times
    n_epoch = 0
    for epoch in range(max_epochs - prev_epochs):
        n_epoch += 1
        running_loss = 0.0
        i_ = 0
        for split_i, train_set in enumerate(train_sets):
            if len(train_sets) > 1:
                train_set.open_candidate_tsvs()
                train_loader = torch.utils.data.DataLoader(train_set,
                                                           batch_size=batch_size,
                                                           num_workers=num_threads, pin_memory=True,
                                                           sampler=samplers[split_i])
            for i, data in enumerate(train_loader, 0):
                i_ += 1
                # get the inputs
                (inputs, labels, var_pos_s, var_len_s, _), _ = data
                # wrap them in Variable
                inputs, labels, var_pos_s, var_len_s = Variable(inputs), Variable(
                    labels), Variable(var_pos_s), Variable(var_len_s)
                if use_cuda:
                    inputs, labels, var_pos_s, var_len_s = inputs.cuda(
                    ), labels.cuda(), var_pos_s.cuda(), var_len_s.cuda()

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs, _ = net(inputs)
                [outputs_classification, outputs_pos, outputs_len] = outputs
                var_len_labels = Variable(
                    torch.LongTensor(var_len_s.cpu().data.numpy()))
                if use_cuda:
                    var_len_labels = var_len_labels.cuda()
                loss = criterion_crossentropy(outputs_classification, labels) + 1 * criterion_smoothl1(
                    outputs_pos.squeeze(1), var_pos_s[:, 1]
                ) + 1 * criterion_crossentropy2(outputs_len, var_len_labels)
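                # Multi-task loss: weighted cross-entropy on the variant-type
                # class, SmoothL1 on the predicted variant position (column 1
                # of var_pos_s), and weighted cross-entropy on the length
                # class, each term weighted equally (factor 1).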
|
|
                loss.backward()
                optimizer.step()
                loss_s.append(loss.data)

                running_loss += loss.data
                if i_ % print_freq == print_freq - 1:
                    logger.info('epoch: {}, iter: {:>7}, lr: {}, loss: {:.5f}'.format(
                        n_epoch + prev_epochs, len(loss_s),
                        learning_rate, running_loss / print_freq))
                    running_loss = 0.0
            if len(train_sets) > 1:
                train_set.close_candidate_tsvs()

        curr_epoch = n_epoch + prev_epochs
        if curr_epoch % save_freq == 0:
            torch.save({"state_dict": net.state_dict(),
                        "tag": tag,
                        "epoch": curr_epoch,
                        "coverage_thr": coverage_thr,
                        "normalize_channels": normalize_channels,
                        }, '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch))
            if validation_candidates_tsv:
                test(net, curr_epoch, validation_loader, use_cuda)
        if curr_epoch % lr_drop_epochs == 0:
            learning_rate *= lr_drop_ratio
            optimizer = optim.SGD(
                net.parameters(), lr=learning_rate, momentum=momentum)
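            # Step decay: every lr_drop_epochs epochs the learning rate is
            # scaled by lr_drop_ratio and a fresh SGD optimizer is created
            # (which also resets the momentum buffers). A roughly equivalent
            # built-in would be torch.optim.lr_scheduler.StepLR(
            #     optimizer, step_size=lr_drop_epochs, gamma=lr_drop_ratio).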
|
|
    logger.info('Finished Training')

    if len(train_sets) == 1:
        train_sets[0].close_candidate_tsvs()

    curr_epoch = n_epoch + prev_epochs
    torch.save({"state_dict": net.state_dict(),
                "tag": tag,
                "epoch": curr_epoch,
                "coverage_thr": coverage_thr,
                "normalize_channels": normalize_channels,
                }, '{}/models/checkpoint_{}_epoch{}.pth'.format(
        out_dir, tag, curr_epoch))
    if validation_candidates_tsv:
        test(net, curr_epoch, validation_loader, use_cuda)
    logger.info("Total Epochs: {}".format(curr_epoch))
|
|
    logger.info("Training is Done.")

    return '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch)

if __name__ == '__main__':

    FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s'
    logging.basicConfig(level=logging.INFO, format=FORMAT)
    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser(
        description='Train NeuSomatic network')
    parser.add_argument('--candidates_tsv', nargs="*",
                        help='train candidate tsv files', required=True)
    parser.add_argument('--out', type=str,
                        help='output directory', required=True)
    parser.add_argument('--checkpoint', type=str,
                        help='pretrained network model checkpoint path', default=None)
    parser.add_argument('--validation_candidates_tsv', nargs="*",
                        help='validation candidate tsv files', default=[])
    parser.add_argument('--ensemble',
                        help='Enable training for ensemble mode',
                        action="store_true")
    parser.add_argument('--num_threads', type=int,
                        help='number of threads', default=1)
    parser.add_argument('--batch_size', type=int,
                        help='batch size', default=1000)
    parser.add_argument('--max_epochs', type=int,
                        help='maximum number of training epochs', default=1000)
    parser.add_argument('--lr', type=float, help='learning rate', default=0.01)
    parser.add_argument('--lr_drop_epochs', type=int,
                        help='number of epochs to drop learning rate', default=400)
    parser.add_argument('--lr_drop_ratio', type=float,
                        help='learning rate drop scale', default=0.1)
    parser.add_argument('--momentum', type=float,
                        help='SGD momentum', default=0.9)
    parser.add_argument('--boost_none', type=float,
                        help='the amount to boost none-somatic classification weight', default=100)
    parser.add_argument('--none_count_scale', type=float,
                        help='ratio of the none/somatic candidates to use in each training epoch \
                        (the none candidates will be subsampled in each epoch based on this ratio)',
                        default=2)
    parser.add_argument('--max_load_candidates', type=int,
                        help='maximum candidates to load in memory', default=1000000)
    parser.add_argument('--save_freq', type=int,
                        help='the frequency of saving checkpoints in terms of # epochs', default=50)
    parser.add_argument('--merged_candidates_per_tsv', type=int,
                        help='Maximum number of candidates in each merged tsv file', default=10000000)
    parser.add_argument('--merged_max_num_tsvs', type=int,
                        help='Maximum number of merged tsv files \
                        (higher priority than merged_candidates_per_tsv)', default=10)
    parser.add_argument('--overwrite_merged_tsvs',
                        help='if OUT/merged_tsvs/ folder exists overwrite the merged tsvs',
                        action="store_true")
    parser.add_argument('--train_split_len', type=int,
                        help='Maximum number of candidates used in each split of training (>=merged_candidates_per_tsv)',
                        default=10000000)
    parser.add_argument('--coverage_thr', type=int,
                        help='maximum coverage threshold to be used for network input \
                        normalization. \
                        Will be overridden if pretrained model is provided. \
                        For ~50x WGS, coverage_thr=100 should work. \
                        For higher coverage WES, coverage_thr=300 should work.', default=100)
    parser.add_argument('--normalize_channels',
                        help='normalize BQ, MQ, and other bam-info channels by frequency of observed alleles. \
                        Will be overridden if pretrained model is provided',
                        action="store_true")
    args = parser.parse_args()

    logger.info(args)

    use_cuda = torch.cuda.is_available()
    logger.info("use_cuda: {}".format(use_cuda))

    try:
        checkpoint = train_neusomatic(args.candidates_tsv, args.validation_candidates_tsv,
                                      args.out, args.checkpoint, args.num_threads, args.batch_size,
                                      args.max_epochs,
                                      args.lr, args.lr_drop_epochs, args.lr_drop_ratio, args.momentum,
                                      args.boost_none, args.none_count_scale,
                                      args.max_load_candidates, args.coverage_thr, args.save_freq,
                                      args.ensemble,
                                      args.merged_candidates_per_tsv, args.merged_max_num_tsvs,
                                      args.overwrite_merged_tsvs, args.train_split_len,
                                      args.normalize_channels,
                                      use_cuda)
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error("Aborting!")
        logger.error(
            "train.py failure on arguments: {}".format(args))
        raise e
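# Example invocation (illustrative only; paths and values are placeholders):
#   python train.py --candidates_tsv work/dataset/*/candidates*.tsv \
#                   --out work/train --num_threads 10 --batch_size 100 \
#                   --max_epochs 1000 --coverage_thr 100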