--- a
+++ b/dash/main.py
@@ -0,0 +1,767 @@
+# Dash app for inspecting MRNet validation predictions through Class Activation Maps (CAMs)
+
+# -*- coding: utf-8 -*-
+import base64
+import time
+import dash
+import dash_core_components as dcc
+import dash_html_components as html
+import dash_daq as daq
+
+
+import pandas as pd
+import numpy as np
+import numpy as np
+
+import cv2
+
+
+# prepare the data -- begin
+# Validation cases: case id (kept as str) and an integer 'Abnormal' flag per row.
+cases = pd.read_csv(
+    '../data/valid-acl.csv',
+    header=None,
+    names=['Case', 'Abnormal'],
+    dtype={'Case': str, 'Abnormal': np.int64},
+)
+case_list = list(cases['Case'])
+# Per-case model outputs; the callbacks read its labels, preds, axial, coronal and
+# sagittal columns.
+predictions = pd.read_csv('./val_data.csv')
+
+# prepare the data -- end
+
+
+# Page stylesheet (assumed to define the `row` / `N columns` grid classes used below).
+# Passed via external_stylesheets because app.css.append_css external URLs are
+# ignored when Dash serves assets locally (the default in Dash >= 1.0).
+app = dash.Dash(show_undo_redo=False,
+                external_stylesheets=['https://codepen.io/amyoshino/pen/jzXypZ.css'])
+
+# Page layout (top to bottom): title; a row with the case dropdown, the true/predicted
+# label filters and the case count + badge; the summary text; then one column per
+# MRI plane with a slice slider, plane probability, raw slice image, CAM image, caption.
+app.layout = html.Div(
+    html.Div([
+        html.Div(
+            # Title row.
+                html.H1(children='Interpretation of MRNet models through Class Activation Maps (CAM)', 
+                        className='twelve columns',
+                        style={'text-align': 'center'}                      
+                    )
+            ,
+            className="row",
+        ),
+        # Selection row: case dropdown | true/predicted label filters | case count and badge.
+        html.Div(
+            [
+                html.Div(
+                    [
+                        html.P('Select a medical case (i.e. a patient)'),
+                        html.Div([
+                            dcc.Dropdown(
+                                id='cases',
+                                options=[
+                                    {'label': case, 'value': case} for case in case_list
+                                ],
+                                placeholder="Pick a case",
+                                clearable=False
+                            )
+                        ],
+                            style={'margin-bottom': 20}
+                        ),
+                    ],
+                    className='three columns',
+                    style={'margin-top': '10'}  # NOTE(review): unitless '10' is likely ignored by the browser; '10px'?
+                ),
+                html.Div([
+                    html.Div([
+                        html.Div([
+                            html.P('Select true labels :'),
+                            dcc.RadioItems(
+                                id="label_radioitems",
+                                options=[
+                                    {'label': 'Positive (ACL tear)', 'value': 'acl'},
+                                    {'label': 'Negative (Normal)', 'value': 'normal'},
+                                ],
+                                value='acl',
+                                labelStyle={'display': 'inline-block'}
+                            ),
+                        ],
+                            style={'float': 'left', 'width': '45%'}
+                        ),
+                        # Predicted-label filter (same 'acl' / 'normal' values as above).
+                        html.Div([
+                            html.P('Select predicted labels :'),
+                            dcc.RadioItems(
+                                id="pred_radioitems",
+                                options=[
+                                    {'label': 'Positive (ACL tear)', 'value': 'acl'},
+                                    {'label': 'Negative (Normal)', 'value': 'normal'},
+                                ],
+                                value='acl',
+                                labelStyle={'display': 'inline-block'}
+                            ),
+                        ]
+                        ),
+                    ]),
+                ],
+                    className='six columns'
+                ),
+                html.Div([
+                    html.Div(id="number_of_cases"),
+                    html.Span(
+                            id="label_badge", 
+                            className="badge badge-success badge-large",  # initial value only; a callback overwrites className
+                            style={'font-size': '15px'}
+                            ),
+                ],
+                    className='three columns'
+
+                )
+
+            ], className="row"
+        ),
+        # Summary row: text filled by the 'summary' callback plus reading instructions.
+        html.Div([
+
+            html.Div([
+                html.P(id='summary', style={'font-size': '20px'}),
+                html.Div([
+                        html.Div('This probability is a weighted average of the three probabilities of tears over each plane', style={'float': 'left', 'font-size': '20px'}),
+                        html.Div('Slide over the slices of each MRI to inspect highlighted regions of tear as depicted by CAMs', style={'float': 'left', 'font-size': '20px'}),
+
+                    ],
+                    # style={'text-align': 'center'}
+                    )
+                    ],
+                    className="twelve columns"),
+
+        ],
+        className='row'
+        ),
+
+        html.Hr(),
+        # One column per MRI plane: slider, plane probability, slice | CAM images, caption.
+        html.Div(
+            [
+                html.Div([
+                    html.Div([
+                        dcc.Slider(id='slider_axial', updatemode='drag')
+                    ],
+                        style={'margin-right': '5px'}
+
+                    ),
+                    html.Hr(),
+                    html.P(id="score_axial", style={'text-align': 'center'}),
+                    html.Div([
+                        html.Div([
+                            html.Img(
+                                id="slice_axial",
+                            ),
+                        ],
+                            style={'float': 'left', 'margin-right': '5px'}
+                        ),
+                        html.Div([
+                            html.Img(
+                                id="cam_axial",
+                            ),
+                        ],
+                        )
+                    ],
+
+                    ),
+                    html.P(id="title_axial", style={'text-align': 'center'})
+                ],
+                    className="four columns"
+                ),
+                html.Div([
+                    html.Div([
+                        dcc.Slider(id='slider_coronal', updatemode='drag')
+                    ],
+                        style={'margin-right': '5px'}
+
+                    ),
+                    html.Hr(),
+                    html.P(id="score_coronal", style={'text-align': 'center'}),
+                    html.Div([
+                        html.Div([
+                            html.Img(
+                                id="slice_coronal",
+                            ),
+                        ],
+                            style={'float': 'left', 'margin-right': '5px'}
+                        ),
+                        html.Div([
+                            html.Img(
+                                id="cam_coronal",
+                            ),
+                        ],
+                        )
+                    ],
+
+                    ),
+                    html.P(id="title_coronal", style={'text-align': 'center'})
+                ],
+                    className="four columns"
+                ),
+                html.Div([
+                    html.Div([
+                        dcc.Slider(id='slider_sagittal', updatemode='drag')
+                    ]),
+                    html.Hr(),
+                    html.P(id="score_sagittal", style={
+                           'text-align': 'center'}),
+                    html.Div([
+                        html.Div([
+                            html.Img(
+                                id="slice_sagittal",
+                            ),
+                        ],
+                            style={'float': 'left', 'margin-right': '5px'}
+                        ),
+                        html.Div([
+                            html.Img(
+                                id="cam_sagittal",
+                            ),
+                        ],
+                        )
+                    ],
+
+                    ),
+                    html.P(id="title_sagittal", style={'text-align': 'center'})
+
+
+                ],
+                    className="four columns"
+                ),
+
+            ],
+            className='row'
+        )
+    ], className='twelve columns')
+)
+
+
+# select label --- begin
+@app.callback(
+    dash.dependencies.Output('cases', 'options'),
+    [
+        dash.dependencies.Input('label_radioitems', 'value'),
+        dash.dependencies.Input('pred_radioitems', 'value'),
+    ]
+)
+def set_label(selected_label, selected_pred):
+    """Fill the case dropdown with the exams of one confusion-matrix cell.
+
+    selected_label / selected_pred are the radio values 'acl' or 'normal'.
+    A case is predicted 'acl' when predictions['preds'] >= 0.5. Index value i
+    of `predictions` is shown as exam number i + 1130, the offset used across
+    this file. Raises ValueError for any other radio value.
+    """
+    label_codes = {'acl': 1, 'normal': 0}
+    if selected_label not in label_codes or selected_pred not in label_codes:
+        raise ValueError(f'Unexpected selection: {selected_label!r}, {selected_pred!r}')
+    true_mask = predictions['labels'] == label_codes[selected_label]
+    # Two explicit comparisons (not ~mask) so NaN probabilities match neither side.
+    pred_mask = (predictions['preds'] >= 0.5 if selected_pred == 'acl'
+                 else predictions['preds'] < 0.5)
+    filtered_cases = predictions[true_mask & pred_mask].index.tolist()
+
+    filtered_cases = [c + 1130 for c in filtered_cases]
+    options = [{'label': fc, 'value': fc} for fc in filtered_cases]
+    return options
+# select label --- end
+
+# set badge label --- begin
+
+@app.callback(
+    dash.dependencies.Output('label_badge', 'children'),
+    [
+        dash.dependencies.Input('label_radioitems', 'value'),
+        dash.dependencies.Input('pred_radioitems', 'value'),
+    ]
+)
+def set_badge_label(selected_label, selected_pred):
+    """Name the confusion-matrix cell picked with the two radio groups."""
+    # Outer test on the ground truth, inner test on the prediction.
+    if selected_label == 'acl':
+        if selected_pred == 'acl':
+            badge_text = 'true positive case'
+        elif selected_pred == 'normal':
+            badge_text = 'false negative case'
+    elif selected_label == 'normal':
+        if selected_pred == 'acl':
+            badge_text = 'false positive case'
+        elif selected_pred == 'normal':
+            badge_text = 'true negative case'
+    # Any other combination leaves badge_text unset, so the return raises.
+    return badge_text
+# set badge label --- end
+
+# set badge color --- begin
+
+@app.callback(
+    dash.dependencies.Output('label_badge', 'className'),
+    [
+        dash.dependencies.Input('label_radioitems', 'value'),
+        dash.dependencies.Input('pred_radioitems', 'value'),
+    ]
+)
+def set_badge_color(selected_label, selected_pred):
+    """Color the badge by whether the prediction agrees with the true label.
+
+    Matching radio values ('acl'/'acl' or 'normal'/'normal') give the
+    'badge-success' class, mismatches give 'badge-error'. Any value other
+    than 'acl' or 'normal' leaves className unbound, so the callback raises.
+    """
+    valid_choices = ('acl', 'normal')
+    if selected_label in valid_choices and selected_pred in valid_choices:
+        agrees = selected_label == selected_pred
+        className = 'badge badge-success' if agrees else 'badge badge-error'
+
+    return className
+
+
+# set badge color --- end
+
+# set a case value --- begin
+
+@app.callback(
+    dash.dependencies.Output('cases', 'value'),
+    [
+        dash.dependencies.Input('label_radioitems', 'value'),
+        dash.dependencies.Input('pred_radioitems', 'value'),
+    ]
+)
+def set_case_value(selected_label, selected_pred):
+    """Pick a random exam from the chosen confusion-matrix cell.
+
+    selected_label / selected_pred are the radio values 'acl' or 'normal'.
+    A case is predicted 'acl' when predictions['preds'] >= 0.5; the returned
+    exam number is its predictions index + 1130. Raises ValueError for other
+    radio values. If the cell is empty the dropdown keeps its current value
+    (PreventUpdate) rather than np.random.choice failing on an empty list.
+    """
+    label_codes = {'acl': 1, 'normal': 0}
+    if selected_label not in label_codes or selected_pred not in label_codes:
+        raise ValueError(f'Unexpected selection: {selected_label!r}, {selected_pred!r}')
+    true_mask = predictions['labels'] == label_codes[selected_label]
+    pred_mask = (predictions['preds'] >= 0.5 if selected_pred == 'acl'
+                 else predictions['preds'] < 0.5)
+    filtered_cases = predictions[true_mask & pred_mask].index.tolist()
+    if not filtered_cases:
+        raise dash.exceptions.PreventUpdate
+    filtered_cases = [c + 1130 for c in filtered_cases]
+    case_value = np.random.choice(filtered_cases)
+    return case_value
+# set a case value --- end
+
+
+# set summary --- begin
+
+@app.callback(
+    dash.dependencies.Output('summary', 'children'),
+    [
+        dash.dependencies.Input('cases', 'value'),
+        dash.dependencies.Input('label_radioitems', 'value'),
+        dash.dependencies.Input('pred_radioitems', 'value')
+    ]
+)
+def set_summary(selected_case, selected_label, selected_pred):
+    """Describe the selected exam's diagnosis and its ACL tear probability.
+
+    selected_case is an exam number (row position + 1130 in `predictions`);
+    selected_label / selected_pred are the radio values 'acl' or 'normal'.
+    The diagnosis is reported as "correctly" made when the two values agree.
+    Raises ValueError for any other radio value.
+    """
+    valid_choices = ('acl', 'normal')
+    if selected_label not in valid_choices or selected_pred not in valid_choices:
+        raise ValueError(f'Unexpected selection: {selected_label!r}, {selected_pred!r}')
+
+    # Positional lookup of one value instead of materialising the whole column.
+    proba = predictions['preds'].iloc[int(selected_case) - 1130]
+    proba = np.round(proba, 4)
+    status = 'correctly' if selected_label == selected_pred else 'incorrectly'
+
+    if selected_pred == 'acl':
+        summary = f'This patient, denoted by the MRI exam n°{selected_case}, is {status} diagnosed with an ACL tear with an ACL tear probability of {proba}'
+    else:
+        summary = f'This patient, denoted by the MRI exam n°{selected_case}, is {status} diagnosed to be normal with an ACL tear probability of {proba}'
+
+    return summary
+
+
+# set summary --- end
+
+
+
+
+
+
+# set number of cases --- begin
+
+@app.callback(
+    dash.dependencies.Output('number_of_cases', 'children'),
+    [
+        dash.dependencies.Input('label_radioitems', 'value'),
+        dash.dependencies.Input('pred_radioitems', 'value')
+    ]
+)
+def set_number_cases(selected_label, selected_pred):
+    """Report how many exams fall in the chosen confusion-matrix cell.
+
+    selected_label / selected_pred are the radio values 'acl' or 'normal';
+    a case is predicted 'acl' when predictions['preds'] >= 0.5. Raises
+    ValueError for any other radio value.
+    """
+    label_codes = {'acl': 1, 'normal': 0}
+    if selected_label not in label_codes or selected_pred not in label_codes:
+        raise ValueError(f'Unexpected selection: {selected_label!r}, {selected_pred!r}')
+    true_mask = predictions['labels'] == label_codes[selected_label]
+    # Two explicit comparisons (not ~mask) so NaN probabilities match neither side.
+    pred_mask = (predictions['preds'] >= 0.5 if selected_pred == 'acl'
+                 else predictions['preds'] < 0.5)
+    n = int((true_mask & pred_mask).sum())
+
+    msg = f"{n} MRI exams"
+    return msg
+
+
+# set number of cases --- end
+
+
+# update axial slider --- begin
+@app.callback(
+    dash.dependencies.Output('slider_axial', 'value'),
+    [dash.dependencies.Input('cases', 'value')]
+)
+def set_slider_value_axial(selected_case):
+    """Start the axial slider on the middle slice of the selected exam."""
+    # Memory-map instead of loading the whole volume: only the shape is needed.
+    mri = np.load(f'../data/valid/axial/{selected_case}.npy', mmap_mode='r')
+    number_slices = mri.shape[0]
+    return number_slices // 2
+
+
+@app.callback(
+    dash.dependencies.Output('slider_axial', 'max'),
+    [dash.dependencies.Input('cases', 'value')]
+)
+def set_slider_max_axial(selected_case):
+    """Set the axial slider maximum to the last (0-based) slice index."""
+    # Memory-map instead of loading the whole volume: only the shape is needed.
+    mri = np.load(f'../data/valid/axial/{selected_case}.npy', mmap_mode='r')
+    number_slices = mri.shape[0]
+    return number_slices - 1
+
+
+@app.callback(
+    dash.dependencies.Output('slider_axial', 'marks'),
+    [dash.dependencies.Input('cases', 'value')]
+)
+def set_slider_marks_axial(selected_case):
+    """Label every second axial slice index (0, 2, 4, ...) on the slider."""
+    # Memory-map instead of loading the whole volume: only the shape is needed.
+    mri = np.load(f'../data/valid/axial/{selected_case}.npy', mmap_mode='r')
+    number_slices = mri.shape[0]
+    marks = {str(i): str(i) for i in range(0, number_slices, 2)}
+    return marks
+
+# update axial slider --- end
+
+# update coronal slider --- begin
+
+
+@app.callback(
+    dash.dependencies.Output('slider_coronal', 'value'),
+    [dash.dependencies.Input('cases', 'value')]
+)
+def set_slider_value_coronal(selected_case):
+    """Start the coronal slider on the middle slice of the selected exam."""
+    # Memory-map instead of loading the whole volume: only the shape is needed.
+    mri = np.load(f'../data/valid/coronal/{selected_case}.npy', mmap_mode='r')
+    number_slices = mri.shape[0]
+    return number_slices // 2
+
+
+@app.callback(
+    dash.dependencies.Output('slider_coronal', 'max'),
+    [dash.dependencies.Input('cases', 'value')]
+)
+def set_slider_max_coronal(selected_case):
+    """Set the coronal slider maximum to the last (0-based) slice index."""
+    # Memory-map instead of loading the whole volume: only the shape is needed.
+    mri = np.load(f'../data/valid/coronal/{selected_case}.npy', mmap_mode='r')
+    number_slices = mri.shape[0]
+    return number_slices - 1
+
+
+@app.callback(
+    dash.dependencies.Output('slider_coronal', 'marks'),
+    [dash.dependencies.Input('cases', 'value')]
+)
+def set_slider_marks_coronal(selected_case):
+    """Label every second coronal slice index (0, 2, 4, ...) on the slider."""
+    # Memory-map instead of loading the whole volume: only the shape is needed.
+    mri = np.load(f'../data/valid/coronal/{selected_case}.npy', mmap_mode='r')
+    number_slices = mri.shape[0]
+    marks = {str(i): str(i) for i in range(0, number_slices, 2)}
+    return marks
+
+# update coronal slider --- end
+
+# update sagittal slider --- begin
+
+
+@app.callback(
+    dash.dependencies.Output('slider_sagittal', 'value'),
+    [dash.dependencies.Input('cases', 'value')]
+)
+def set_slider_value_sagittal(selected_case):
+    """Start the sagittal slider on the middle slice of the selected exam."""
+    # Memory-map instead of loading the whole volume: only the shape is needed.
+    mri = np.load(f'../data/valid/sagittal/{selected_case}.npy', mmap_mode='r')
+    number_slices = mri.shape[0]
+    return number_slices // 2
+
+
+@app.callback(
+    dash.dependencies.Output('slider_sagittal', 'max'),
+    [dash.dependencies.Input('cases', 'value')]
+)
+def set_slider_max_sagittal(selected_case):
+    """Set the sagittal slider maximum to the last (0-based) slice index."""
+    # Memory-map instead of loading the whole volume: only the shape is needed.
+    mri = np.load(f'../data/valid/sagittal/{selected_case}.npy', mmap_mode='r')
+    number_slices = mri.shape[0]
+    return number_slices - 1
+
+
+@app.callback(
+    dash.dependencies.Output('slider_sagittal', 'marks'),
+    [dash.dependencies.Input('cases', 'value')]
+)
+def set_slider_marks_sagittal(selected_case):
+    """Label every second sagittal slice index (0, 2, 4, ...) on the slider."""
+    # Memory-map instead of loading the whole volume: only the shape is needed.
+    mri = np.load(f'../data/valid/sagittal/{selected_case}.npy', mmap_mode='r')
+    number_slices = mri.shape[0]
+    marks = {str(i): str(i) for i in range(0, number_slices, 2)}
+    return marks
+
+# update sagittal slider --- end
+
+# update slider --- END
+
+# Axial
+##########################################################################
+
+# write number of slice axial - begin
+@app.callback(
+    dash.dependencies.Output('title_axial', 'children'),
+    [
+        dash.dependencies.Input('cases', 'value'),
+        dash.dependencies.Input('slider_axial', 'value')
+    ]
+)
+def write_num_slice_axial(selected_case, selected_slice):
+    """Caption the axial images with the selected slice index and the slice count."""
+    # Fires on every slider drag: memory-map so the volume is not read just for its shape.
+    case = np.load(f'../data/valid/axial/{selected_case}.npy', mmap_mode='r')
+    num_slices = case.shape[0]
+    title = f'Visualization of slice n°{selected_slice}/{num_slices} and its corresponding CAM'
+    return title
+# write number of slice axial - end
+
+
+# write score axial - begin
+@app.callback(
+    dash.dependencies.Output('score_axial', 'children'),
+    [dash.dependencies.Input('cases', 'value')]
+)
+def write_score_axial(selected_case):
+    """Show the axial-plane ACL tear probability for the exam, rounded to 4 places.
+
+    Exam numbers map to rows of `predictions` by subtracting 1130.
+    """
+    row = predictions.iloc[int(selected_case) - 1130]
+    return f"ACL tear proba on axial plane : {np.round(row['axial'], 4)}"
+# write score axial - end
+
+
+# update slice axial --- begin
+@app.callback(
+    dash.dependencies.Output('slice_axial', 'src'),
+    [dash.dependencies.Input('cases', 'value'),
+     dash.dependencies.Input('slider_axial', 'value')])
+def update_slice_axial(selected_case, selected_slice):
+    """Return the chosen axial slice as a base64 PNG data URI, encoded in memory."""
+    # mmap reads only this slice; in-memory encoding avoids racing on a shared temp PNG.
+    s = np.load(f'../data/valid/axial/{selected_case}.npy', mmap_mode='r')[selected_slice]
+    ok, png = cv2.imencode('.png', s)
+    if not ok:
+        raise ValueError(f'Could not encode axial slice {selected_slice} of case {selected_case} as PNG')
+    return 'data:image/png;base64,{}'.format(base64.b64encode(png.tobytes()).decode())
+# update slice axial --- end
+# update cam axial --- begin
+
+@app.callback(
+    dash.dependencies.Output('cam_axial', 'src'),
+    [dash.dependencies.Input('cases', 'value'),
+     dash.dependencies.Input('slider_axial', 'value')])
+def update_cam_axial(selected_case, selected_slice):
+    """Return the stored axial CAM image for one slice as a base64 PNG data URI.
+
+    The file is ./CAMS/axial/<NNNN>/cams/<slice>.png, where NNNN is the exam
+    number minus 1130, zero-padded to four digits.
+    """
+    cam_dir = f'{int(selected_case) - 1130:04d}'
+    with open(f'./CAMS/axial/{cam_dir}/cams/{selected_slice}.png', 'rb') as cam_file:
+        encoded_image = base64.b64encode(cam_file.read())
+    return 'data:image/png;base64,{}'.format(encoded_image.decode())
+# update cam axial --- end
+
+# Coronal
+##########################################################################
+
+# write number of slice coronal - begin
+@app.callback(
+    dash.dependencies.Output('title_coronal', 'children'),
+    [
+        dash.dependencies.Input('cases', 'value'),
+        dash.dependencies.Input('slider_coronal', 'value')
+    ]
+)
+def write_num_slice_coronal(selected_case, selected_slice):
+    """Caption the coronal images with the selected slice index and the slice count."""
+    # Fires on every slider drag: memory-map so the volume is not read just for its shape.
+    case = np.load(f'../data/valid/coronal/{selected_case}.npy', mmap_mode='r')
+    num_slices = case.shape[0]
+    title = f'Visualization of slice n°{selected_slice}/{num_slices} and its corresponding CAM'
+    return title
+# write number of slice coronal - end
+
+# write score coronal - begin
+
+
+@app.callback(
+    dash.dependencies.Output('score_coronal', 'children'),
+    [dash.dependencies.Input('cases', 'value')]
+)
+def write_score_coronal(selected_case):
+    """Show the coronal-plane ACL tear probability for the exam, rounded to 4 places.
+
+    Exam numbers map to rows of `predictions` by subtracting 1130.
+    """
+    row = predictions.iloc[int(selected_case) - 1130]
+    return f"ACL tear proba on coronal plane : {np.round(row['coronal'], 4)}"
+# write score coronal - end
+
+
+# update slice coronal --- begin
+@app.callback(
+    dash.dependencies.Output('slice_coronal', 'src'),
+    [dash.dependencies.Input('cases', 'value'),
+     dash.dependencies.Input('slider_coronal', 'value')])
+def update_slice_coronal(selected_case, selected_slice):
+    """Return the chosen coronal slice as a base64 PNG data URI, encoded in memory."""
+    # mmap reads only this slice; in-memory encoding avoids racing on a shared temp PNG.
+    s = np.load(f'../data/valid/coronal/{selected_case}.npy', mmap_mode='r')[selected_slice]
+    ok, png = cv2.imencode('.png', s)
+    if not ok:
+        raise ValueError(f'Could not encode coronal slice {selected_slice} of case {selected_case} as PNG')
+    return 'data:image/png;base64,{}'.format(base64.b64encode(png.tobytes()).decode())
+# update slice coronal --- end
+# update cam coronal --- begin
+
+@app.callback(
+    dash.dependencies.Output('cam_coronal', 'src'),
+    [dash.dependencies.Input('cases', 'value'),
+     dash.dependencies.Input('slider_coronal', 'value')])
+def update_cam_coronal(selected_case, selected_slice):
+    """Return the stored coronal CAM image for one slice as a base64 PNG data URI.
+
+    The file is ./CAMS/coronal/<NNNN>/cams/<slice>.png, where NNNN is the exam
+    number minus 1130, zero-padded to four digits.
+    """
+    cam_dir = f'{int(selected_case) - 1130:04d}'
+    with open(f'./CAMS/coronal/{cam_dir}/cams/{selected_slice}.png', 'rb') as cam_file:
+        encoded_image = base64.b64encode(cam_file.read())
+    return 'data:image/png;base64,{}'.format(encoded_image.decode())
+
+# update cam coronal --- end
+
+# Sagittal
+##########################################################################
+
+# write number of slice sagittal - begin
+@app.callback(
+    dash.dependencies.Output('title_sagittal', 'children'),
+    [
+        dash.dependencies.Input('cases', 'value'),
+        dash.dependencies.Input('slider_sagittal', 'value')
+    ]
+)
+def write_num_slice_sagittal(selected_case, selected_slice):
+    """Caption the sagittal images with the selected slice index and the slice count."""
+    # Fires on every slider drag: memory-map so the volume is not read just for its shape.
+    case = np.load(f'../data/valid/sagittal/{selected_case}.npy', mmap_mode='r')
+    num_slices = case.shape[0]
+    title = f'Visualization of slice n°{selected_slice}/{num_slices} and its corresponding CAM'
+    return title
+# write number of slice sagittal - end
+
+# write score sagittal - begin
+
+
+@app.callback(
+    dash.dependencies.Output('score_sagittal', 'children'),
+    [dash.dependencies.Input('cases', 'value')]
+)
+def write_score_sagittal(selected_case):
+    """Show the sagittal-plane ACL tear probability for the exam, rounded to 4 places.
+
+    Exam numbers map to rows of `predictions` by subtracting 1130.
+    """
+    row = predictions.iloc[int(selected_case) - 1130]
+    return f"ACL tear proba on sagittal plane : {np.round(row['sagittal'], 4)}"
+# write score sagittal - end
+
+
+# update slice sagittal --- begin
+@app.callback(
+    dash.dependencies.Output('slice_sagittal', 'src'),
+    [dash.dependencies.Input('cases', 'value'),
+     dash.dependencies.Input('slider_sagittal', 'value')])
+def update_slice_sagittal(selected_case, selected_slice):
+    """Return the chosen sagittal slice as a base64 PNG data URI, encoded in memory."""
+    # mmap reads only this slice; in-memory encoding avoids racing on a shared temp PNG.
+    s = np.load(f'../data/valid/sagittal/{selected_case}.npy', mmap_mode='r')[selected_slice]
+    ok, png = cv2.imencode('.png', s)
+    if not ok:
+        raise ValueError(f'Could not encode sagittal slice {selected_slice} of case {selected_case} as PNG')
+    return 'data:image/png;base64,{}'.format(base64.b64encode(png.tobytes()).decode())
+# update slice sagittal --- end
+# update cam sagittal --- begin
+
+@app.callback(
+    dash.dependencies.Output('cam_sagittal', 'src'),
+    [dash.dependencies.Input('cases', 'value'),
+     dash.dependencies.Input('slider_sagittal', 'value')])
+def update_cam_sagittal(selected_case, selected_slice):
+    """Return the stored sagittal CAM image for one slice as a base64 PNG data URI.
+
+    The file is ./CAMS/sagittal/<NNNN>/cams/<slice>.png, where NNNN is the exam
+    number minus 1130, zero-padded to four digits.
+    """
+    cam_dir = f'{int(selected_case) - 1130:04d}'
+    with open(f'./CAMS/sagittal/{cam_dir}/cams/{selected_slice}.png', 'rb') as cam_file:
+        encoded_image = base64.b64encode(cam_file.read())
+    return 'data:image/png;base64,{}'.format(encoded_image.decode())
+# update cam sagittal --- end
+
+# Script entry point. NOTE(review): debug=True enables Dash/Flask debug mode; keep it to local use.
+if __name__ == '__main__':
+    app.run_server(debug=True)