In [1]:
import plotly.express as px
import matplotlib.pyplot as plt
import pandas as pd

In [4]:
def _load_scores(output_name, network, output_dir="outputs"):
    """Load one run's per-class Dice and IoU CSVs and reshape them to long form.

    Reads ``{output_dir}/{network}/{output_name}/infer_dice_class.csv`` and
    ``.../infer_iou_class.csv``. Each row is tagged with its metric name in a
    ``type`` column, and every other column is melted into
    (``classes``, ``scores``) pairs.

    Returns:
        pd.DataFrame with columns ``type``, ``classes`` and ``scores``, sorted
        by ``type`` then ``classes`` and carrying a fresh 0..n-1 index.

    Raises:
        FileNotFoundError: if either CSV is missing (propagated from pandas).
    """
    run_dir = f"{output_dir}/{network}/{output_name}"
    # NOTE(review): every column except ``type`` is treated as a class. If the
    # CSVs were written with their index (an "Unnamed: 0" column), that column
    # would show up as a spurious class -- confirm against the writer.
    per_metric = [
        pd.read_csv(f"{run_dir}/infer_{metric}_class.csv").assign(type=metric)
        for metric in ("dice", "iou")
    ]
    return (
        pd.concat(per_metric, ignore_index=True)
        .melt(id_vars=["type"], var_name="classes", value_name="scores")
        .sort_values(["type", "classes"])
        .reset_index(drop=True)
    )


def boxplot(output_name, network, y_min=None, y_max=None,
            output_dir="outputs", title=None):
    """Show a per-class box plot of Dice and IoU scores for one inference run.

    Args:
        output_name: Run directory name under ``{output_dir}/{network}/``.
        network: Network directory name under ``output_dir``.
        y_min: Lower y-axis limit. If only ``y_max`` is given, the smallest
            observed score is used instead.
        y_max: Upper y-axis limit. If only ``y_min`` is given, the largest
            observed score is used instead. If both limits are ``None``,
            plotly autoranges the y-axis.
        output_dir: Root directory holding the inference outputs.
        title: Figure title. Defaults to ``output_name``.

    Returns:
        None. The figure is rendered via ``fig.show()``. It is deliberately not
        returned, so that a notebook cell ending in this call does not render
        it a second time.
    """
    scores_long = _load_scores(output_name, network, output_dir)

    fig = px.box(
        scores_long,
        x="classes",
        y="scores",
        color="type",
        title=output_name if title is None else title,
        labels={"classes": "Class", "scores": "Score", "type": "Metric"},
    )
    fig.update_traces(quartilemethod="exclusive")
    fig.update_layout(width=1000, height=500, font=dict(size=18))

    # Previously a lone y_min or y_max was silently ignored. Now a missing
    # bound is filled from the data so that one-sided limits take effect.
    if y_min is not None or y_max is not None:
        lower = scores_long["scores"].min() if y_min is None else y_min
        upper = scores_long["scores"].max() if y_max is None else y_max
        fig.update_yaxes(range=[lower, upper])
    fig.show()

### VHSCDD $(256 \times 256)$

In [8]:
# VHSCDD run of RotCAtt_TransUNet_plusplus, with the y-axis fixed to [0.65, 1].
network, output_name = (
    'RotCAtt_TransUNet_plusplus',
    'VHSCDD_RotCAtt_TransUNet_plusplus_bs24_ps16_epo600_hw256_ly4',
)
boxplot(output_name=output_name, network=network, y_min=0.65, y_max=1)

### MMWHS $(256 \times 256)$

In [7]:
# MMWHS run of RotCAtt_TransUNet_plusplus, with the y-axis fixed to [0.65, 1].
network, output_name = (
    'RotCAtt_TransUNet_plusplus',
    'MMWHS_RotCAtt_TransUNet_plusplus_bs24_ps16_epo600_hw256_ly4',
)
boxplot(output_name=output_name, network=network, y_min=0.65, y_max=1)

### ImageCHD $(256 \times 256)$

In [6]:
# ImageCHD run of RotCAtt_TransUNet_plusplus, with the y-axis fixed to [0.8, 1].
network, output_name = (
    'RotCAtt_TransUNet_plusplus',
    'Imagechd_RotCAtt_TransUNet_plusplus_bs24_ps16_epo600_hw256_ly4',
)
boxplot(output_name=output_name, network=network, y_min=0.8, y_max=1)

In [11]:
# Synapse run of RotCAtt_TransUNet_plusplus. No y-limits are passed, so the
# y-axis range is left to plotly.
network, output_name = (
    'RotCAtt_TransUNet_plusplus',
    'Synapse_RotCAtt_TransUNet_plusplus_bs24_ps16_epo600_hw256_ly4',
)
boxplot(output_name=output_name, network=network)