Diff of /app/backend/app.py [000000] .. [69507b]

Switch to unified view

a b/app/backend/app.py
1
import io
2
import json
3
from typing import Optional # required for "Optional[type]"
4
from PIL import Image
5
import pandas as pd
6
from flask import Flask, request,send_from_directory
7
import os
8
import cv2
9
import pydicom
10
import png
11
import numpy as np
12
import matplotlib.pyplot as plt
13
import sys
14
15
import torch,torchvision
16
from torch import nn
17
from torch import Tensor
18
from torchvision import models
19
import torchvision.transforms as transforms
20
import torch
21
import torchvision
22
from Utils import use_gradcam
23
24
from flask_cors import CORS
25
26
from pathlib import Path
27
28
app = Flask(__name__)
29
app.config["DEBUG"] = True
30
CORS(app)
31
UPLOAD_FOLDER = './input_folder'
32
GRADCAM_FOLDER='./gradcam_imgs'
33
ALLOWED_EXTENSIONS = {'png', 'dcm'}
34
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
35
36
################ START_OF_MODEL ################
37
#Code modified and taken from Andrea de Luca (https://bit.ly/2YXW6xN)
38
device = torch.device("cpu")
39
40
class Flatten(nn.Module):
41
    "Flatten `x` to a single dimension, often used at the end of a model. `full` for rank-1 tensor"
42
    def __init__(self, full:bool=False):
43
        super().__init__()
44
        self.full = full
45
46
    def forward(self, x):
47
        return x.view(-1) if self.full else x.view(x.size(0), -1)
48
49
class AdaptiveConcatPool2d(nn.Module):
50
    "Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`." # from pytorch
51
    def __init__(self, sz:Optional[int]=None):
52
        "Output will be 2*sz or 2 if sz is None"
53
        super().__init__()
54
        self.output_size = sz or 1
55
        self.ap = nn.AdaptiveAvgPool2d(self.output_size)
56
        self.mp = nn.AdaptiveMaxPool2d(self.output_size)
57
    def forward(self, x):
58
        return torch.cat([self.mp(x), self.ap(x)], 1)
59
60
def myhead(nf, nc):
61
    '''
62
    Inputs: nf=  # of in_features in the 4th layer , nc= # of classes
63
    '''
64
    return \
65
    nn.Sequential(        # the dropout is needed otherwise you cannot load the weights
66
            AdaptiveConcatPool2d(),
67
            Flatten(),
68
            nn.BatchNorm1d(nf,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True),
69
            nn.Dropout(p=0.25,inplace=False),
70
            nn.Linear(nf, 512,bias=True),
71
            nn.ReLU(inplace=True),
72
            nn.BatchNorm1d(512,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True),
73
            nn.Dropout(p=0.5,inplace=False),
74
            nn.Linear(512, nc,bias=True),
75
        )
76
77
78
my_model=torchvision.models.resnet34()
79
modules=list(my_model.children())
80
modules.pop(-1)
81
modules.pop(-1)
82
temp=nn.Sequential(nn.Sequential(*modules))
83
tempchildren=list(temp.children())
84
85
#append the special fastai head
86
#Configured according to Model Architecture
87
88
tempchildren.append(myhead(1024,3))
89
model_r34=nn.Sequential(*tempchildren)
90
91
#LOAD MODEL
92
state = torch.load(Path('corona_resnet34.pth').resolve(),map_location=torch.device('cpu'))
93
model_r34.load_state_dict(state['model'])
94
95
#important to set to evaluation mode
96
model_r34.eval()
97
98
################ END_OF_MODEL ################
99
100
test_transforms = transforms.Compose([
101
    transforms.Resize(512),
102
    transforms.CenterCrop(512),
103
    transforms.ToTensor(),
104
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
105
                    std=[0.229, 0.224, 0.225])
106
])
107
108
#accepts png files
109
def predict_image(image):
110
    softmaxer = torch.nn.Softmax(dim=1)
111
    image_tensor = Image.open(image)
112
    image_tensor = image_tensor.convert('RGB')
113
    image_tensor = test_transforms(image_tensor).float()
114
    image_tensor=image_tensor.unsqueeze(0)
115
116
    #convert evaluation to probabilities with softmax
117
    with torch.no_grad(): #turn off backpropagation
118
      processed=softmaxer(model_r34(image_tensor))
119
    return (processed[0]) #return probabilities
120
121
def get_metadata(folder,filename, attribute):
122
    '''
123
    Given a path to folder of images, patient ID, and attribute, return useful meta-data from the corresponding dicom image.
124
    IMPLICITLY Converts dicom image to png in the process and puts to test folder
125
    Returns attribute value, png image (implicit)
126
    '''
127
    ds=pydicom.dcmread(folder+'/'+filename+'.dcm')
128
129
    #implicit DICOM -> PNG conversion
130
    shape = ds.pixel_array.shape
131
    # Convert to float to avoid overflow or underflow losses.
132
    image_2d = ds.pixel_array.astype(float)
133
    # Rescaling grey scale between 0-255
134
    image_2d_scaled = (np.maximum(image_2d,0) / image_2d.max()) * 255.0
135
    # Convert to uint
136
    image_2d_scaled = np.uint8(image_2d_scaled)
137
    # Write the PNG file
138
    with open(os.path.join(folder,filename+'.png'), 'wb') as png_file:
139
        w = png.Writer(shape[1], shape[0], greyscale=True)
140
        w.write(png_file, image_2d_scaled)
141
    try:
142
      attribute_value = getattr(ds, attribute)
143
      return attribute_value
144
    except: return np.NaN
145
146
########Implementation Part###################################
147
#for original images
148
@app.route('/uploads/<path:filename>')
149
def download_file(filename):
150
    #argument is in the form of filename.extension
151
    filename=os.path.splitext(os.path.basename(filename))[0]
152
    if filename[-1]==".":
153
        filename = filename[:-1]
154
    if os.path.exists('./input_folder/{}.png'.format(filename)):
155
        return send_from_directory(UPLOAD_FOLDER,'{}.png'.format(filename), as_attachment=True)
156
157
    if os.path.exists('./input_folder/{}.jpg'.format(filename)):
158
        return send_from_directory(UPLOAD_FOLDER,'{}.jpg'.format(filename), as_attachment=True)
159
160
    if os.path.exists('./input_folder/{}.jpeg'.format(filename)):
161
        return send_from_directory(UPLOAD_FOLDER,'{}.jpeg'.format(filename), as_attachment=True)
162
163
#for gradcam images
164
@app.route('/gradcam/<path:filename>')
165
def download_gradcam_file(filename):
166
    #argument is in the form of filename.extension 
167
    use_gradcam(os.path.join(UPLOAD_FOLDER,filename),GRADCAM_FOLDER,model_r34,test_transforms)
168
    filename=os.path.splitext(os.path.basename(filename))[0]
169
    return send_from_directory(GRADCAM_FOLDER,'(gradcam){}.png'.format(filename), as_attachment=True)       
170
171
@app.route('/', methods=['POST'])
172
def predict():
173
    '''
174
    Inputs: a list of image filenames ending with an extension (e.x. .png) taken from UPLOAD_FOLDER
175
    Returns: a json of predictions_df
176
    '''
177
    if request.method == 'POST':
178
        if not os.path.isdir(UPLOAD_FOLDER):
179
            os.makedirs(UPLOAD_FOLDER)
180
        if not os.path.isdir(GRADCAM_FOLDER):
181
            os.makedirs(GRADCAM_FOLDER)
182
        for filename in os.listdir(UPLOAD_FOLDER):
183
            file_path = os.path.join(UPLOAD_FOLDER, filename)
184
            print(file_path,file=sys.stderr)
185
            try:
186
                if os.path.isfile(file_path) or os.path.islink(file_path):
187
                    os.unlink(file_path)
188
                elif os.path.isdir(file_path):
189
                    os.shutil.rmtree(file_path)
190
            except Exception as e:
191
                print('Failed to delete %s. Reason: %s' % (file_path, e))
192
        for filename in os.listdir(GRADCAM_FOLDER):
193
            file_path = os.path.join(GRADCAM_FOLDER, filename)
194
            try:
195
                if os.path.isfile(file_path) or os.path.islink(file_path):
196
                    os.unlink(file_path)
197
                elif os.path.isdir(file_path):
198
                    os.shutil.rmtree(file_path)
199
            except Exception as e:
200
                print('Failed to delete %s. Reason: %s' % (file_path, e))
201
        data = dict(request.files)
202
        for key in data.keys():
203
            data[key].save(os.path.join(UPLOAD_FOLDER,'{}'.format(data[key].filename)))
204
205
        print("images saved!")
206
207
        #METADATA and CONVERT TO PNG
208
        #list of files to be converted
209
        files = [f[:-4]+'.png' for f in os.listdir(UPLOAD_FOLDER) if f.endswith('.dcm')]
210
        result_df=pd.DataFrame(files,columns=['filename'])
211
212
        #list of essential attributes
213
        attributes = ['PatientID','PatientSex', 'PatientAge', 'ViewPosition']
214
        for a in attributes:
215
            result_df[a] = result_df['filename'].apply(lambda x: get_metadata(UPLOAD_FOLDER, x, a))
216
217
        #PREDICTION
218
        #each image in test_files must be a filename.png from the upload folder
219
        test_files=[file for file in sorted(os.listdir(UPLOAD_FOLDER))if file.endswith(('.png','.jpg','.jpeg'))]
220
        df_results={filename:predict_image(UPLOAD_FOLDER+'/'+filename) for filename in test_files}
221
        print("predictions done")
222
        #OUTPUT DATAFRAMES
223
        predictions_df=pd.DataFrame.from_dict(df_results,orient='index',columns=['covid','nofinding','opacity']).rename_axis('filename').reset_index()
224
        predictions_df['covid']=predictions_df['covid'].apply(lambda x: x.item())
225
        predictions_df['nofinding']=predictions_df['nofinding'].apply(lambda x: x.item())
226
        predictions_df['opacity']=predictions_df['opacity'].apply(lambda x: x.item())
227
        #get the column name of the highest probability
228
        predictions_df['Predicted Label'] =predictions_df[['covid','opacity','nofinding']].idxmax(axis=1)
229
        print("table done")
230
        print("gradcam done")
231
        #predictions_df['filename']=predictions_df['filename'].apply(lambda file: os.path.splitext(file)[0]) #remove .png suffix
232
        #merge result_df and final_df
233
        if result_df.empty:
234
            for a in attributes:
235
                predictions_df[a]="" #include empty columns for proper json formatting
236
            final_df=predictions_df
237
        else:
238
            final_df=pd.merge(result_df,predictions_df[['filename','Predicted Label']], on='filename')
239
            #convert age to int to be used later
240
            final_df['PatientAge'] = pd.to_numeric(final_df['PatientAge'], errors='coerce')
241
242
        print("Generating Results!")
243
        result = final_df.to_json(orient='records') #format: [{"filename":a,... metadata( 'PatientID','PatientSex', 'PatientAge', 'ViewPosition')..., "Predicted Label":f}]
244
        return result;
245
246
if __name__ == '__main__':
247
    app.run(host='0.0.0.0', port=5000)