Diff of /main.py [000000] .. [df6751]

Switch to unified view

a b/main.py
1
import time
2
import copy
3
import pandas as pd
4
import torch
5
from torch.autograd import Variable
6
from densenet import densenet169
7
from utils import plot_training, n_p, get_count
8
from train import train_model, get_metrics
9
from pipeline import get_study_level_data, get_dataloaders
10
11
# #### load study level dict data
12
study_data = get_study_level_data(study_type='XR_WRIST')
13
14
# #### Create dataloaders pipeline
15
data_cat = ['train', 'valid'] # data categories
16
dataloaders = get_dataloaders(study_data, batch_size=1)
17
dataset_sizes = {x: len(study_data[x]) for x in data_cat}
18
19
# #### Build model
20
# tai = total abnormal images, tni = total normal images
21
tai = {x: get_count(study_data[x], 'positive') for x in data_cat}
22
tni = {x: get_count(study_data[x], 'negative') for x in data_cat}
23
Wt1 = {x: n_p(tni[x] / (tni[x] + tai[x])) for x in data_cat}
24
Wt0 = {x: n_p(tai[x] / (tni[x] + tai[x])) for x in data_cat}
25
26
print('tai:', tai)
27
print('tni:', tni, '\n')
28
print('Wt0 train:', Wt0['train'])
29
print('Wt0 valid:', Wt0['valid'])
30
print('Wt1 train:', Wt1['train'])
31
print('Wt1 valid:', Wt1['valid'])
32
33
class Loss(torch.nn.modules.Module):
34
    def __init__(self, Wt1, Wt0):
35
        super(Loss, self).__init__()
36
        self.Wt1 = Wt1
37
        self.Wt0 = Wt0
38
        
39
    def forward(self, inputs, targets, phase):
40
        loss = - (self.Wt1[phase] * targets * inputs.log() + self.Wt0[phase] * (1 - targets) * (1 - inputs).log())
41
        return loss
42
43
model = densenet169(pretrained=True)
44
model = model.cuda()
45
46
criterion = Loss(Wt1, Wt0)
47
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
48
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, verbose=True)
49
50
# #### Train model
51
model = train_model(model, criterion, optimizer, dataloaders, scheduler, dataset_sizes, num_epochs=5)
52
53
torch.save(model.state_dict(), 'models/model.pth')
54
55
get_metrics(model, criterion, dataloaders, dataset_sizes)