--- a +++ b/dash/main.py @@ -0,0 +1,767 @@ +# Introducing callbacks + +# -*- coding: utf-8 -*- +import base64 +import time +import dash +import dash_core_components as dcc +import dash_html_components as html +import dash_daq as daq + + +import pandas as pd +import numpy as np +import numpy as np + +import cv2 + + +# prepare the data -- begin + +cases = pd.read_csv('../data/valid-acl.csv', + header=None, + names=['Case', 'Abnormal'], + dtype={'Case': str, 'Abnormal': np.int64} + ) +case_list = cases['Case'].tolist() + + +predictions = pd.read_csv('./val_data.csv') + + +# prepare the data -- end + + +app = dash.Dash(show_undo_redo=False) + +# Boostrap CSS. + +app.css.append_css({'external_url': 'https://codepen.io/amyoshino/pen/jzXypZ.css'}) # noqa: E501 +#app.css.append_css({'external_url': "https://stackpath.bootstrapcdn.com/bootstrap/3.4.1/css/bootstrap.min.css"}) # noqa: E501 + + + +app.layout = html.Div( + html.Div([ + html.Div( + + html.H1(children='Interpretation of MRNet models through Class Activation Maps (CAM)', + className='twelve columns', + style={'text-align': 'center'} + ) + , + className="row", + ), + + html.Div( + [ + html.Div( + [ + html.P('Select a medical case (i.e. a patient)'), + html.Div([ + dcc.Dropdown( + id='cases', + options=[ + {'label': case, 'value': case} for case in case_list + ], + placeholder="Pick a case", + clearable=False + ) + ], + style={'margin-bottom': 20} + ), + ], + className='three columns', + style={'margin-top': '10'} + ), + html.Div([ + html.Div([ + html.Div([ + html.P('Select true labels :'), + dcc.RadioItems( + id="label_radioitems", + options=[ + {'label': 'Positive (ACL tear)', 'value': 'acl'}, + {'label': 'Negative (Normal)', 'value': 'normal'}, + ], + value='acl', + labelStyle={'display': 'inline-block'} + ), + ], + style={'float': 'left', 'width': '45%'} + ), + + html.Div([ + html.P('Select predicted labels :'), + dcc.RadioItems( + id="pred_radioitems", + options=[ + {'label': 'Positive (ACL tear)', 'value': 'acl'}, + {'label': 'Negative (Normal)', 'value': 'normal'}, + ], + value='acl', + labelStyle={'display': 'inline-block'} + ), + ] + ), + ]), + ], + className='six columns' + ), + html.Div([ + html.Div(id="number_of_cases"), + html.Span( + id="label_badge", + className="badge badge-success badge-large", + style={'font-size': '15px'} + ), + ], + className='three columns' + + ) + + ], className="row" + ), + + html.Div([ + + html.Div([ + html.P(id='summary', style={'font-size': '20px'}), + html.Div([ + html.Div('This probability is a weighted average of the three probabilities of tears over each plane', style={'float': 'left', 'font-size': '20px'}), + html.Div('Slide over the slices of each MRI to inspect highlighted regions of tear as depicted by CAMs', style={'float': 'left', 'font-size': '20px'}), + + ], + # style={'text-align': 'center'} + ) + ], + className="twelve columns"), + + ], + className='row' + ), + + html.Hr(), + + html.Div( + [ + html.Div([ + html.Div([ + dcc.Slider(id='slider_axial', updatemode='drag') + ], + style={'margin-right': '5px'} + + ), + html.Hr(), + html.P(id="score_axial", style={'text-align': 'center'}), + html.Div([ + html.Div([ + html.Img( + id="slice_axial", + ), + ], + style={'float': 'left', 'margin-right': '5px'} + ), + html.Div([ + html.Img( + id="cam_axial", + ), + ], + ) + ], + + ), + html.P(id="title_axial", style={'text-align': 'center'}) + ], + className="four columns" + ), + html.Div([ + html.Div([ + dcc.Slider(id='slider_coronal', updatemode='drag') + ], + style={'margin-right': '5px'} + + ), + html.Hr(), + html.P(id="score_coronal", style={'text-align': 'center'}), + html.Div([ + html.Div([ + html.Img( + id="slice_coronal", + ), + ], + style={'float': 'left', 'margin-right': '5px'} + ), + html.Div([ + html.Img( + id="cam_coronal", + ), + ], + ) + ], + + ), + html.P(id="title_coronal", style={'text-align': 'center'}) + ], + className="four columns" + ), + html.Div([ + html.Div([ + dcc.Slider(id='slider_sagittal', updatemode='drag') + ]), + html.Hr(), + html.P(id="score_sagittal", style={ + 'text-align': 'center'}), + html.Div([ + html.Div([ + html.Img( + id="slice_sagittal", + ), + ], + style={'float': 'left', 'margin-right': '5px'} + ), + html.Div([ + html.Img( + id="cam_sagittal", + ), + ], + ) + ], + + ), + html.P(id="title_sagittal", style={'text-align': 'center'}) + + + ], + className="four columns" + ), + + ], + className='row' + ) + ], className='twelve columns') +) + + +# select label --- begin +@app.callback( + dash.dependencies.Output('cases', 'options'), + [ + dash.dependencies.Input('label_radioitems', 'value'), + dash.dependencies.Input('pred_radioitems', 'value'), + ] +) +def set_label(selected_label, selected_pred): + if (selected_label == 'acl') and (selected_pred == 'acl'): + filtered_cases = predictions[(predictions['labels'] == 1) & + (predictions['preds'] >= 0.5)].index.tolist() + + elif (selected_label == 'acl') and (selected_pred == 'normal'): + filtered_cases = predictions[(predictions['labels'] == 1) & + (predictions['preds'] < 0.5)].index.tolist() + + elif (selected_label == 'normal') and (selected_pred == 'acl'): + filtered_cases = predictions[(predictions['labels'] == 0) & + (predictions['preds'] >= 0.5)].index.tolist() + + elif (selected_label == 'normal') and (selected_pred == 'normal'): + filtered_cases = predictions[(predictions['labels'] == 0) & + (predictions['preds'] < 0.5)].index.tolist() + + filtered_cases = [c + 1130 for c in filtered_cases] + options = [{'label': fc, 'value': fc} for fc in filtered_cases] + return options +# select label --- end + +# set badge label --- begin + +@app.callback( + dash.dependencies.Output('label_badge', 'children'), + [ + dash.dependencies.Input('label_radioitems', 'value'), + dash.dependencies.Input('pred_radioitems', 'value'), + ] +) +def set_badge_label(selected_label, selected_pred): + if (selected_label == 'acl') and (selected_pred == 'acl'): + text = 'true positive case' + + elif (selected_label == 'acl') and (selected_pred == 'normal'): + text = 'false negative case' + + elif (selected_label == 'normal') and (selected_pred == 'acl'): + text = 'false positive case' + + elif (selected_label == 'normal') and (selected_pred == 'normal'): + text = 'true negative case' + + return text + +# set badge label --- end + +# set badge color --- begin + +@app.callback( + dash.dependencies.Output('label_badge', 'className'), + [ + dash.dependencies.Input('label_radioitems', 'value'), + dash.dependencies.Input('pred_radioitems', 'value'), + ] +) +def set_badge_color(selected_label, selected_pred): + if (selected_label == 'acl') and (selected_pred == 'acl'): + className = 'badge badge-success' + + elif (selected_label == 'acl') and (selected_pred == 'normal'): + className = 'badge badge-error' + + elif (selected_label == 'normal') and (selected_pred == 'acl'): + className = 'badge badge-error' + + elif (selected_label == 'normal') and (selected_pred == 'normal'): + className = 'badge badge-success' + + return className + +# set badge color --- end + +# set a case value --- begin + +@app.callback( + dash.dependencies.Output('cases', 'value'), + [ + dash.dependencies.Input('label_radioitems', 'value'), + dash.dependencies.Input('pred_radioitems', 'value'), + ] +) +def set_badge_color(selected_label, selected_pred): + if (selected_label == 'acl') and (selected_pred == 'acl'): + filtered_cases = predictions[(predictions['labels'] == 1) & + (predictions['preds'] >= 0.5)].index.tolist() + + elif (selected_label == 'acl') and (selected_pred == 'normal'): + filtered_cases = predictions[(predictions['labels'] == 1) & + (predictions['preds'] < 0.5)].index.tolist() + + elif (selected_label == 'normal') and (selected_pred == 'acl'): + filtered_cases = predictions[(predictions['labels'] == 0) & + (predictions['preds'] >= 0.5)].index.tolist() + + elif (selected_label == 'normal') and (selected_pred == 'normal'): + filtered_cases = predictions[(predictions['labels'] == 0) & + (predictions['preds'] < 0.5)].index.tolist() + + filtered_cases = [c + 1130 for c in filtered_cases] + case_value = np.random.choice(filtered_cases) + return case_value + +# set a case value --- end + + +# set summary --- begin + +@app.callback( + dash.dependencies.Output('summary', 'children'), + [ + dash.dependencies.Input('cases', 'value'), + dash.dependencies.Input('label_radioitems', 'value'), + dash.dependencies.Input('pred_radioitems', 'value') + ] +) +def set_summary(selected_case, selected_label, selected_pred): + + proba = predictions['preds'].tolist()[int(selected_case) - 1130] + proba = np.round(proba, 4) + + if (selected_label == 'acl') and (selected_pred == 'acl'): + status = 'correctly' + + elif (selected_label == 'acl') and (selected_pred == 'normal'): + status = 'incorrectly' + + elif (selected_label == 'normal') and (selected_pred == 'acl'): + status = 'incorrectly' + + elif (selected_label == 'normal') and (selected_pred == 'normal'): + status = 'correctly' + + if selected_pred == 'acl': + summary = f'This patient, denoted by the MRI exam n°{selected_case}, is {status} diagnosed with an ACL tear with an ACL tear probability of {proba}' + elif selected_pred == 'normal': + summary = f'This patient, denoted by the MRI exam n°{selected_case}, is {status} diagnosed to be normal with an ACL tear probability of {proba}' + + + return summary + +# set summary --- end + + + + + + +# set number of cases --- begin + +@app.callback( + dash.dependencies.Output('number_of_cases', 'children'), + [ + dash.dependencies.Input('label_radioitems', 'value'), + dash.dependencies.Input('pred_radioitems', 'value') + ] +) +def set_number_cases(selected_label, selected_pred): + if (selected_label == 'acl') and (selected_pred == 'acl'): + n = predictions[(predictions['labels'] == 1) & + (predictions['preds'] >= 0.5)].shape[0] + + elif (selected_label == 'acl') and (selected_pred == 'normal'): + n = predictions[(predictions['labels'] == 1) & + (predictions['preds'] < 0.5)].shape[0] + + elif (selected_label == 'normal') and (selected_pred == 'acl'): + n = predictions[(predictions['labels'] == 0) & + (predictions['preds'] >= 0.5)].shape[0] + + elif (selected_label == 'normal') and (selected_pred == 'normal'): + n = predictions[(predictions['labels'] == 0) & + (predictions['preds'] < 0.5)].shape[0] + + msg = f"{n} MRI exams" + return msg + +# set number of cases --- end + + +# update axial slider --- begin +@app.callback( + dash.dependencies.Output('slider_axial', 'value'), + [ + dash.dependencies.Input('cases', 'value'), + ] +) +def set_slider_value_axial(selected_case): + mri = np.load(f'../data/valid/axial/{selected_case}.npy') + number_slices = mri.shape[0] + return number_slices // 2 + + +@app.callback( + dash.dependencies.Output('slider_axial', 'max'), + [ + dash.dependencies.Input('cases', 'value'), + ] +) +def set_slider_max_axial(selected_case): + mri = np.load(f'../data/valid/axial/{selected_case}.npy') + number_slices = mri.shape[0] + return number_slices - 1 + + +@app.callback( + dash.dependencies.Output('slider_axial', 'marks'), + [ + dash.dependencies.Input('cases', 'value'), + ] +) +def set_slider_marks_axial(selected_case): + mri = np.load(f'../data/valid/axial/{selected_case}.npy') + number_slices = mri.shape[0] + marks = {str(i): '{}'.format(i) for i in range(number_slices)[::2]} + return marks + +# update axial slider --- end + +# update coronal slider --- begin + + +@app.callback( + dash.dependencies.Output('slider_coronal', 'value'), + [ + dash.dependencies.Input('cases', 'value'), + ] +) +def set_slider_value_coronal(selected_case): + mri = np.load(f'../data/valid/coronal/{selected_case}.npy') + number_slices = mri.shape[0] + return number_slices // 2 + + +@app.callback( + dash.dependencies.Output('slider_coronal', 'max'), + [ + dash.dependencies.Input('cases', 'value'), + ] +) +def set_slider_max_coronal(selected_case): + mri = np.load(f'../data/valid/coronal/{selected_case}.npy') + number_slices = mri.shape[0] + return number_slices - 1 + + +@app.callback( + dash.dependencies.Output('slider_coronal', 'marks'), + [ + dash.dependencies.Input('cases', 'value'), + ] +) +def set_slider_marks_coronal(selected_case): + mri = np.load(f'../data/valid/coronal/{selected_case}.npy') + number_slices = mri.shape[0] + marks = {str(i): '{}'.format(i) for i in range(number_slices)[::2]} + return marks + +# update coronal slider --- end + +# update sagittal slider --- begin + + +@app.callback( + dash.dependencies.Output('slider_sagittal', 'value'), + [ + dash.dependencies.Input('cases', 'value'), + ] +) +def set_slider_value_sagittal(selected_case): + mri = np.load(f'../data/valid/sagittal/{selected_case}.npy') + number_slices = mri.shape[0] + return number_slices // 2 + + +@app.callback( + dash.dependencies.Output('slider_sagittal', 'max'), + [ + dash.dependencies.Input('cases', 'value'), + ] +) +def set_slider_max_sagittal(selected_case): + mri = np.load(f'../data/valid/sagittal/{selected_case}.npy') + number_slices = mri.shape[0] + return number_slices - 1 + + +@app.callback( + dash.dependencies.Output('slider_sagittal', 'marks'), + [ + dash.dependencies.Input('cases', 'value'), + ] +) +def set_slider_marks_sagittal(selected_case): + mri = np.load(f'../data/valid/sagittal/{selected_case}.npy') + number_slices = mri.shape[0] + marks = {str(i): '{}'.format(i) for i in range(number_slices)[::2]} + return marks + +# update sagittal slider --- end + +# update slider --- END + +# Axial +########################################################################## + +# write number of slice axial - begin + + +@app.callback( + dash.dependencies.Output('title_axial', 'children'), + [ + dash.dependencies.Input('cases', 'value'), + dash.dependencies.Input('slider_axial', 'value') + ] +) +def write_num_slice_axial(selected_case, selected_slice): + case = np.load(f'../data/valid/axial/{selected_case}.npy') + num_slices = case.shape[0] + title = f'Visualization of slice n°{selected_slice}/{num_slices} and its corresponding CAM' + return title +# write number of slice axial - end + + +# write score axial - begin +@app.callback( + dash.dependencies.Output('score_axial', 'children'), + [ + dash.dependencies.Input('cases', 'value'), + ] +) +def write_score_axial(selected_case): + score = predictions.iloc[int(selected_case) - 1130]['axial'] + score = np.round(score, 4) + msg = f"ACL tear proba on axial plane : {score}" + return msg +# write score axial + + +# update slice axial --- begin +@app.callback( + dash.dependencies.Output('slice_axial', 'src'), + [ + dash.dependencies.Input('cases', 'value'), + dash.dependencies.Input('slider_axial', 'value'), + + ]) +def update_slice_axial(selected_case, selected_slice): + s = np.load(f'../data/valid/axial/{selected_case}.npy')[selected_slice] + cv2.imwrite(f'./slice_axial.png', s) + encoded_image = base64.b64encode(open('./slice_axial.png', 'rb').read()) + return 'data:image/png;base64,{}'.format(encoded_image.decode()) +# update slice axial --- end +# update cam axial --- begin + + +@app.callback( + dash.dependencies.Output('cam_axial', 'src'), + [ + dash.dependencies.Input('cases', 'value'), + dash.dependencies.Input('slider_axial', 'value'), + + ]) +def update_cam_axial(selected_case, selected_slice): + selected_case = int(selected_case) - 1130 + selected_case = '0' * (4 - len(str(selected_case))) + str(selected_case) + src = f'./CAMS/axial/{selected_case}/cams/{selected_slice}.png' + encoded_image = base64.b64encode(open(src, 'rb').read()) + return 'data:image/png;base64,{}'.format(encoded_image.decode()) +# update slice axial --- end + +# Coronal +########################################################################## + +# write number of slice coronal - begin + + +@app.callback( + dash.dependencies.Output('title_coronal', 'children'), + [ + dash.dependencies.Input('cases', 'value'), + dash.dependencies.Input('slider_coronal', 'value') + ] +) +def write_num_slice_coronal(selected_case, selected_slice): + case = np.load(f'../data/valid/coronal/{selected_case}.npy') + num_slices = case.shape[0] + title = f'Visualization of slice n°{selected_slice}/{num_slices} and its corresponding CAM' + return title +# write number of slice coronal - end + +# write score coronal - begin + + +@app.callback( + dash.dependencies.Output('score_coronal', 'children'), + [ + dash.dependencies.Input('cases', 'value'), + ] +) +def write_score_coronal(selected_case): + score = predictions.iloc[int(selected_case) - 1130]['coronal'] + score = np.round(score, 4) + msg = f"ACL tear proba on coronal plane : {score}" + return msg +# write score coronal + + +# update slice coronal --- begin +@app.callback( + dash.dependencies.Output('slice_coronal', 'src'), + [ + dash.dependencies.Input('cases', 'value'), + dash.dependencies.Input('slider_coronal', 'value'), + + ]) +def update_slice_coronal(selected_case, selected_slice): + s = np.load(f'../data/valid/coronal/{selected_case}.npy')[selected_slice] + cv2.imwrite(f'./slice_coronal.png', s) + encoded_image = base64.b64encode(open('./slice_coronal.png', 'rb').read()) + return 'data:image/png;base64,{}'.format(encoded_image.decode()) +# update slice coronal --- end +# update cam coronal --- begin + + +@app.callback( + dash.dependencies.Output('cam_coronal', 'src'), + [ + dash.dependencies.Input('cases', 'value'), + dash.dependencies.Input('slider_coronal', 'value'), + + ]) +def update_cam_coronal(selected_case, selected_slice): + selected_case = int(selected_case) - 1130 + selected_case = '0' * (4 - len(str(selected_case))) + str(selected_case) + src = f'./CAMS/coronal/{selected_case}/cams/{selected_slice}.png' + encoded_image = base64.b64encode(open(src, 'rb').read()) + return 'data:image/png;base64,{}'.format(encoded_image.decode()) + +# update slice coronal --- end + +# Sagittal +########################################################################## + +# write number of slice sagittal - begin + + +@app.callback( + dash.dependencies.Output('title_sagittal', 'children'), + [ + dash.dependencies.Input('cases', 'value'), + dash.dependencies.Input('slider_sagittal', 'value') + ] +) +def write_num_slice_sagittal(selected_case, selected_slice): + case = np.load(f'../data/valid/sagittal/{selected_case}.npy') + num_slices = case.shape[0] + title = f'Visualization of slice n°{selected_slice}/{num_slices} and its corresponding CAM' + return title +# write number of slice sagittal - end + +# write score sagittal - begin + + +@app.callback( + dash.dependencies.Output('score_sagittal', 'children'), + [ + dash.dependencies.Input('cases', 'value'), + ] +) +def write_score_sagittal(selected_case): + score = predictions.iloc[int(selected_case) - 1130]['sagittal'] + score = np.round(score, 4) + msg = f"ACL tear proba on sagittal plane : {score}" + return msg +# write score sagittal + + +# update slice sagittal --- begin +@app.callback( + dash.dependencies.Output('slice_sagittal', 'src'), + [ + dash.dependencies.Input('cases', 'value'), + dash.dependencies.Input('slider_sagittal', 'value'), + + ]) +def update_slice_sagittal(selected_case, selected_slice): + s = np.load(f'../data/valid/sagittal/{selected_case}.npy')[selected_slice] + cv2.imwrite(f'./slice_sagittal.png', s) + encoded_image = base64.b64encode(open('./slice_sagittal.png', 'rb').read()) + return 'data:image/png;base64,{}'.format(encoded_image.decode()) +# update slice saigttal --- end +# update cam sagittal --- begin + + +@app.callback( + dash.dependencies.Output('cam_sagittal', 'src'), + [ + dash.dependencies.Input('cases', 'value'), + dash.dependencies.Input('slider_sagittal', 'value'), + + ]) +def update_cam_sagittal(selected_case, selected_slice): + selected_case = int(selected_case) - 1130 + selected_case = '0' * (4 - len(str(selected_case))) + str(selected_case) + src = f'./CAMS/sagittal/{selected_case}/cams/{selected_slice}.png' + encoded_image = base64.b64encode(open(src, 'rb').read()) + return 'data:image/png;base64,{}'.format(encoded_image.decode()) +# update slice coronal --- end + + +if __name__ == '__main__': + app.run_server(debug=True)