|
a |
|
b/plot_sactter.py |
|
|
1 |
import pandas as pd |
|
|
2 |
import matplotlib.pyplot as plt |
|
|
3 |
from mpl_toolkits.mplot3d import Axes3D |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
def _scatter_by_label(ax, latent_code, label_file, colour_file, latent_code_dim, legend_anchor):
    """Draw one colour-coded scatter series per label found in the latent code.

    Args:
        ax: Matplotlib axes (2D or 3D) to draw on.
        latent_code: DataFrame whose first ``latent_code_dim`` columns are the
            latent coordinates; its index is joined against the label file.
        label_file: Tab-separated file read with its first column as index.
            After the merge, the column at position ``latent_code_dim`` is
            used as the label of each sample.
        colour_file: Tab-separated file; column 0 is used as the colour and
            column 1 as the label code it belongs to.
        latent_code_dim: Number of latent coordinates to plot (2 or 3).
        legend_anchor: ``bbox_to_anchor`` value passed to the legend.
    """
    disease_id = pd.read_csv(label_file, sep='\t', index_col=0)
    latent_code_label = pd.merge(latent_code, disease_id, left_index=True, right_index=True)
    colour_setting = pd.read_csv(colour_file, sep='\t')

    # The label column and its distinct values do not change inside the loop,
    # so compute them once instead of once per colour entry.
    label_column = latent_code_label.iloc[:, latent_code_dim]
    present_codes = label_column.unique()

    for colour, code in zip(colour_setting.iloc[:, 0], colour_setting.iloc[:, 1]):
        if code not in present_codes:
            continue
        part = latent_code_label[label_column == code]
        coords = [part.iloc[:, axis] for axis in range(latent_code_dim)]
        ax.scatter(*coords, s=2, marker='o', alpha=0.8, c=colour, label=code)

    ax.legend(ncol=2, markerscale=4, bbox_to_anchor=legend_anchor, loc='upper left', frameon=False)


def plot_scatter(latent_code, output_path,
                 label_file='data/PANCAN/GDC-PANCAN_both_samples_tumour_type.tsv',
                 colour_file='data/TCGA_colors_obvious.tsv', latent_code_dim=2, have_label=True):
    """Save a 2D or 3D scatter plot of the latent space as PNG and SVG.

    The figure is written to ``results/<name><dim>D_fig.png`` and
    ``results/<name><dim>D_fig.svg``, where ``<name>`` is the part of
    ``output_path`` after its last ``/``.

    Args:
        latent_code: DataFrame whose first ``latent_code_dim`` columns are the
            latent coordinates of each sample.
        output_path: Path whose last ``/``-separated component names the
            output files.
        label_file: Tab-separated sample-label file (only read when
            ``have_label`` is true).
        colour_file: Tab-separated colour table (only read when
            ``have_label`` is true).
        latent_code_dim: Dimensionality of the latent code. Nothing is done
            for values greater than 3.
        have_label: When true, colour the samples by label and add a legend.
    """
    import os

    if latent_code_dim > 3:
        return

    labelled_plot = have_label and latent_code_dim in (2, 3)
    # Always start from a fresh figure. Relying on pyplot's "current figure"
    # made the unlabelled 2D case draw on top of whatever was plotted before.
    fig = plt.figure(figsize=(8, 5.5)) if labelled_plot else plt.figure()
    try:
        if latent_code_dim == 3:
            # Plot the 3D scatter graph of latent space
            ax = fig.add_subplot(111, projection='3d')
            if have_label:
                _scatter_by_label(ax, latent_code, label_file, colour_file, 3, (1, 0.9))
            else:
                ax.scatter(latent_code.iloc[:, 0], latent_code.iloc[:, 1], latent_code.iloc[:, 2],
                           s=2, marker='o', alpha=0.8)
            ax.set_xlabel('First Latent Dimension')
            ax.set_ylabel('Second Latent Dimension')
            ax.set_zlabel('Third Latent Dimension')
        elif latent_code_dim == 2:
            ax = fig.add_subplot(111)
            if have_label:
                _scatter_by_label(ax, latent_code, label_file, colour_file, 2, (1, 1))
            else:
                ax.scatter(latent_code.iloc[:, 0], latent_code.iloc[:, 1], s=2, marker='o', alpha=0.8)
            ax.set_xlabel('First Latent Dimension')
            ax.set_ylabel('Second Latent Dimension')
        # NOTE(review): for latent_code_dim < 2 nothing is drawn and an empty
        # figure is still saved below; confirm whether that is intended.

        input_file_name = output_path.split('/')[-1]
        fig_path = 'results/' + input_file_name + str(latent_code_dim) + 'D_fig.png'
        fig_path_svg = 'results/' + input_file_name + str(latent_code_dim) + 'D_fig.svg'

        # savefig does not create missing directories.
        os.makedirs('results', exist_ok=True)
        fig.tight_layout()
        fig.savefig(fig_path, dpi=300)
        fig.savefig(fig_path_svg)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)