|
a |
|
b/visualizations.py |
|
|
1 |
import seaborn as sns |
|
|
2 |
import plotly.graph_objects as go |
|
|
3 |
import plotly.express as px |
|
|
4 |
import pandas as pd |
|
|
5 |
import matplotlib.pyplot as plt |
|
|
6 |
|
|
|
7 |
def plot_losses(train_loss_history, val_loss_history, num_epochs): |
|
|
8 |
|
|
|
9 |
fig = go.Figure() |
|
|
10 |
|
|
|
11 |
line_color = [px.colors.qualitative.Light24[5], px.colors.qualitative.Light24[4]] |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
fig.add_trace(go.Scatter(x=list(range(1, num_epochs+1)), y=train_loss_history, mode='lines', line=dict(color=line_color[0], width=1), name="train loss")) |
|
|
15 |
fig.add_trace(go.Scatter(x=list(range(1, num_epochs+1)), y=val_loss_history, mode='lines',line=dict(color=line_color[1], width=1), name="valdiation loss")) #opacity=0.8 |
|
|
16 |
|
|
|
17 |
fig.update_yaxes(range=[0, 1]) |
|
|
18 |
fig.update_traces(textposition='top center') |
|
|
19 |
fig.update_layout(autosize=False,width=900, height=500, title_text="SNN Loss", title_x=0.5, xaxis_title="Epoch", yaxis_title="", xaxis = dict(tickmode='linear', tick0=1, dtick=1), legend=dict(yanchor="top",xanchor="right", x=1.35, y=1),template="plotly_dark") |
|
|
20 |
fig.show() |
|
|
21 |
|
|
|
22 |
|
|
|
23 |
def plot_similarity_scores(non_matching_similarity, matching_similarity): |
|
|
24 |
|
|
|
25 |
train_num_batchs = len(non_matching_similarity) |
|
|
26 |
|
|
|
27 |
fig = go.Figure() |
|
|
28 |
|
|
|
29 |
line_color = [px.colors.qualitative.Light24[22], px.colors.qualitative.Light24[19]] |
|
|
30 |
|
|
|
31 |
|
|
|
32 |
fig.add_trace(go.Scatter(x=list(range(1, train_num_batchs+1)), y=non_matching_similarity, mode='lines', line=dict(color=line_color[0], width=1), name="unmatching categories")) |
|
|
33 |
fig.add_trace(go.Scatter(x=list(range(1, train_num_batchs+1)), y=matching_similarity, mode='lines',line=dict(color=line_color[1], width=1), name="matching categories")) #opacity=0.8 |
|
|
34 |
|
|
|
35 |
fig.update_yaxes(range=[0, 1]) |
|
|
36 |
fig.update_traces(textposition='top center') |
|
|
37 |
fig.update_layout(autosize=False,width=900, height=500, title_text="Similarity Scores", title_x=0.5, xaxis_title="Batch", yaxis_title="", legend=dict(yanchor="top",xanchor="right", x=1.35, y=1),template="plotly_dark") |
|
|
38 |
fig.show() |
|
|
39 |
|
|
|
40 |
|
|
|
41 |
|
|
|
42 |
def show_space(X, title, colors=None, color_by="", show_3D=False): |
|
|
43 |
|
|
|
44 |
if show_3D: |
|
|
45 |
dictionary = dict(zip(pd.DataFrame(X).columns, ["COMP1", "COMP2", "COMP3"])) |
|
|
46 |
temp_df = pd.DataFrame(X).rename(columns=dictionary) |
|
|
47 |
fig = px.scatter_3d(temp_df, x='COMP1', y='COMP2', z='COMP3',color=colors, template="plotly_dark", labels={"color": color_by}) |
|
|
48 |
else: |
|
|
49 |
dictionary = dict(zip(X.columns, ["COMP1", "COMP2"])) |
|
|
50 |
temp_df = pd.DataFrame(X).rename(columns=dictionary) |
|
|
51 |
fig = px.scatter(temp_df, x='COMP1', y='COMP2',color=colors, template="plotly_dark", labels={"color": color_by}) |
|
|
52 |
|
|
|
53 |
fig.update_traces(marker=dict(size=4, opacity=0.98), textposition='top center') |
|
|
54 |
fig.update_layout(title_text=title, title_x=0.5, autosize=False,width=900, height=500, legend=dict(yanchor="top",xanchor="right", x=1.1, y=1)) |
|
|
55 |
fig.show() |
|
|
56 |
#fig.write_html("file.html") |
|
|
57 |
|
|
|
58 |
|
|
|
59 |
|
|
|
60 |
def plot_confusion_matrix(mat,fig_size, labels): |
|
|
61 |
|
|
|
62 |
fig = plt.figure(figsize=(fig_size,fig_size)) |
|
|
63 |
ax= fig.add_subplot(1,1,1) |
|
|
64 |
sns.heatmap(mat, annot=True, cmap="Blues",ax = ax,fmt='g'); #annot=True to annotate cells |
|
|
65 |
|
|
|
66 |
# labels, title and ticks |
|
|
67 |
ax.set_xlabel('Predicted labels');ax.set_ylabel('True labels'); |
|
|
68 |
ax.set_title('Confusion Matrix'); |
|
|
69 |
ax.xaxis.set_ticklabels(labels); ax.yaxis.set_ticklabels(labels); |
|
|
70 |
plt.setp(ax.get_yticklabels(), rotation=30, horizontalalignment='right') |
|
|
71 |
plt.setp(ax.get_xticklabels(), rotation=30, horizontalalignment='right') |
|
|
72 |
plt.show() |