# main.py

import copy
import time

import pandas as pd
import torch
from torch.autograd import Variable

from densenet import densenet169
from pipeline import get_study_level_data, get_dataloaders
from train import train_model, get_metrics
from utils import plot_training, n_p, get_count

# #### Load study-level dict data
study_data = get_study_level_data(study_type='XR_WRIST')
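
# Assumption (not stated in the original script): the other MURA study types,
# e.g. 'XR_ELBOW', 'XR_HAND', 'XR_SHOULDER', can be trained the same way by
# changing study_type above, provided the pipeline supports them.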

# #### Create dataloaders pipeline
data_cat = ['train', 'valid']  # data categories
dataloaders = get_dataloaders(study_data, batch_size=1)
dataset_sizes = {x: len(study_data[x]) for x in data_cat}
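
# batch_size=1 keeps one study per batch. A plausible rationale (an assumption,
# not stated in the original): studies contain a variable number of views, so
# fixed-size image batching would need extra collation logic.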

# #### Build model
# tai = total abnormal images, tni = total normal images
tai = {x: get_count(study_data[x], 'positive') for x in data_cat}
tni = {x: get_count(study_data[x], 'negative') for x in data_cat}
Wt1 = {x: n_p(tni[x] / (tni[x] + tai[x])) for x in data_cat}
Wt0 = {x: n_p(tai[x] / (tni[x] + tai[x])) for x in data_cat}
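
# The weights above set up a class-weighted binary cross-entropy:
# Wt1 = tni / (tai + tni) is applied to abnormal (positive) targets and
# Wt0 = tai / (tai + tni) to normal (negative) targets, so the rarer class
# contributes more per image. Worked example: with tai = 100 abnormal and
# tni = 300 normal images, Wt1 = 300/400 = 0.75 and Wt0 = 100/400 = 0.25.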

print('tai:', tai)
print('tni:', tni, '\n')
print('Wt0 train:', Wt0['train'])
print('Wt0 valid:', Wt0['valid'])
print('Wt1 train:', Wt1['train'])
print('Wt1 valid:', Wt1['valid'])

class Loss(torch.nn.Module):
    """Class-weighted binary cross-entropy with per-phase weights Wt1
    (abnormal) and Wt0 (normal)."""

    def __init__(self, Wt1, Wt0):
        super(Loss, self).__init__()
        self.Wt1 = Wt1
        self.Wt0 = Wt0

    def forward(self, inputs, targets, phase):
        # inputs are expected to be probabilities in (0, 1); targets are 0/1
        loss = -(self.Wt1[phase] * targets * inputs.log() +
                 self.Wt0[phase] * (1 - targets) * (1 - inputs).log())
        return loss
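
# Illustrative sanity check of the loss on dummy values (a sketch only; it
# assumes n_p() returns tensors on the same device as the dummy inputs, so
# uncomment and adapt only if that holds in your setup):
# p = torch.tensor([0.9]); y = torch.tensor([1.0])
# print(Loss(Wt1, Wt0)(p, y, 'train'))  # small loss: confident correct prediction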

model = densenet169(pretrained=True)
model = model.cuda()  # the script assumes a CUDA-capable GPU

criterion = Loss(Wt1, Wt0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, verbose=True)
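
# With mode='min' and patience=1, ReduceLROnPlateau cuts the learning rate by
# its default factor of 0.1 once the monitored loss fails to improve for more
# than one epoch. This assumes train_model calls scheduler.step(val_loss)
# once per epoch, as ReduceLROnPlateau requires.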

# #### Train model
model = train_model(model, criterion, optimizer, dataloaders, scheduler, dataset_sizes, num_epochs=5)

torch.save(model.state_dict(), 'models/model.pth')
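
# Note: the 'models/' directory must exist before the save above. To reload
# the trained weights later, the standard PyTorch pattern is:
# model.load_state_dict(torch.load('models/model.pth'))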

get_metrics(model, criterion, dataloaders, dataset_sizes)