Diff of /visualizations.py [000000] .. [87f2bb]

Switch to unified view

a b/visualizations.py
1
import seaborn as sns
2
import plotly.graph_objects as go
3
import plotly.express as px
4
import pandas as pd
5
import matplotlib.pyplot as plt
6
7
def plot_losses(train_loss_history, val_loss_history, num_epochs):
8
9
    fig = go.Figure()
10
11
    line_color = [px.colors.qualitative.Light24[5], px.colors.qualitative.Light24[4]]
12
13
14
    fig.add_trace(go.Scatter(x=list(range(1, num_epochs+1)), y=train_loss_history, mode='lines', line=dict(color=line_color[0], width=1), name="train loss"))
15
    fig.add_trace(go.Scatter(x=list(range(1, num_epochs+1)), y=val_loss_history, mode='lines',line=dict(color=line_color[1], width=1), name="valdiation loss")) #opacity=0.8
16
17
    fig.update_yaxes(range=[0, 1])
18
    fig.update_traces(textposition='top center')
19
    fig.update_layout(autosize=False,width=900, height=500, title_text="SNN Loss", title_x=0.5, xaxis_title="Epoch", yaxis_title="", xaxis = dict(tickmode='linear', tick0=1, dtick=1), legend=dict(yanchor="top",xanchor="right", x=1.35, y=1),template="plotly_dark")
20
    fig.show()
21
22
23
def plot_similarity_scores(non_matching_similarity, matching_similarity):
24
25
    train_num_batchs = len(non_matching_similarity)
26
27
    fig = go.Figure()
28
29
    line_color = [px.colors.qualitative.Light24[22], px.colors.qualitative.Light24[19]]
30
31
32
    fig.add_trace(go.Scatter(x=list(range(1, train_num_batchs+1)), y=non_matching_similarity, mode='lines', line=dict(color=line_color[0], width=1), name="unmatching categories"))
33
    fig.add_trace(go.Scatter(x=list(range(1, train_num_batchs+1)), y=matching_similarity, mode='lines',line=dict(color=line_color[1], width=1), name="matching categories")) #opacity=0.8
34
35
    fig.update_yaxes(range=[0, 1])
36
    fig.update_traces(textposition='top center')
37
    fig.update_layout(autosize=False,width=900, height=500, title_text="Similarity Scores", title_x=0.5, xaxis_title="Batch", yaxis_title="", legend=dict(yanchor="top",xanchor="right", x=1.35, y=1),template="plotly_dark")
38
    fig.show()
39
40
41
42
def show_space(X, title, colors=None, color_by="", show_3D=False):
43
44
  if show_3D:
45
    dictionary = dict(zip(pd.DataFrame(X).columns, ["COMP1", "COMP2", "COMP3"]))
46
    temp_df = pd.DataFrame(X).rename(columns=dictionary)
47
    fig = px.scatter_3d(temp_df, x='COMP1', y='COMP2', z='COMP3',color=colors, template="plotly_dark", labels={"color": color_by})
48
  else:
49
    dictionary = dict(zip(X.columns, ["COMP1", "COMP2"]))
50
    temp_df = pd.DataFrame(X).rename(columns=dictionary)
51
    fig = px.scatter(temp_df, x='COMP1', y='COMP2',color=colors, template="plotly_dark", labels={"color": color_by})
52
53
  fig.update_traces(marker=dict(size=4, opacity=0.98), textposition='top center')
54
  fig.update_layout(title_text=title, title_x=0.5, autosize=False,width=900, height=500, legend=dict(yanchor="top",xanchor="right", x=1.1, y=1))
55
  fig.show()
56
  #fig.write_html("file.html")
57
58
59
60
def plot_confusion_matrix(mat,fig_size, labels):
61
62
    fig = plt.figure(figsize=(fig_size,fig_size))
63
    ax= fig.add_subplot(1,1,1)
64
    sns.heatmap(mat, annot=True, cmap="Blues",ax = ax,fmt='g'); #annot=True to annotate cells
65
66
    # labels, title and ticks
67
    ax.set_xlabel('Predicted labels');ax.set_ylabel('True labels');
68
    ax.set_title('Confusion Matrix');
69
    ax.xaxis.set_ticklabels(labels); ax.yaxis.set_ticklabels(labels);
70
    plt.setp(ax.get_yticklabels(), rotation=30, horizontalalignment='right')
71
    plt.setp(ax.get_xticklabels(), rotation=30, horizontalalignment='right')
72
    plt.show()