Diff of /plot_sactter.py [000000] .. [2d53aa]

Switch to unified view

a b/plot_sactter.py
1
import pandas as pd
2
import matplotlib.pyplot as plt
3
from mpl_toolkits.mplot3d import Axes3D
4
5
6
def plot_scatter(latent_code, output_path,
7
                 label_file='data/PANCAN/GDC-PANCAN_both_samples_tumour_type.tsv',
8
                 colour_file='data/TCGA_colors_obvious.tsv', latent_code_dim=2, have_label=True):
9
    if latent_code_dim <= 3:
10
        if latent_code_dim == 3:
11
            # Plot the 3D scatter graph of latent space
12
            if have_label:
13
                # Set sample label
14
                disease_id = pd.read_csv(label_file, sep='\t', index_col=0)
15
                latent_code_label = pd.merge(latent_code, disease_id, left_index=True, right_index=True)
16
                colour_setting = pd.read_csv(colour_file, sep='\t')
17
                fig = plt.figure(figsize=(8, 5.5))
18
                ax = fig.add_subplot(111, projection='3d')
19
                for index in range(len(colour_setting)):
20
                    code = colour_setting.iloc[index, 1]
21
                    colour = colour_setting.iloc[index, 0]
22
                    if code in latent_code_label.iloc[:, latent_code_dim].unique():
23
                        latent_code_label_part = latent_code_label[latent_code_label.iloc[:, latent_code_dim] == code]
24
                        ax.scatter(latent_code_label_part.iloc[:, 0], latent_code_label_part.iloc[:, 1],
25
                                   latent_code_label_part.iloc[:, 2], s=2, marker='o', alpha=0.8, c=colour, label=code)
26
                ax.legend(ncol=2, markerscale=4, bbox_to_anchor=(1, 0.9), loc='upper left', frameon=False)
27
            else:
28
                fig = plt.figure()
29
                ax = fig.add_subplot(111, projection='3d')
30
                ax.scatter(latent_code.iloc[:, 0], latent_code.iloc[:, 1], latent_code.iloc[:, 2], s=2, marker='o',
31
                           alpha=0.8)
32
            ax.set_xlabel('First Latent Dimension')
33
            ax.set_ylabel('Second Latent Dimension')
34
            ax.set_zlabel('Third Latent Dimension')
35
        elif latent_code_dim == 2:
36
            if have_label:
37
                # Set sample label
38
                disease_id = pd.read_csv(label_file, sep='\t', index_col=0)
39
                latent_code_label = pd.merge(latent_code, disease_id, left_index=True, right_index=True)
40
                colour_setting = pd.read_csv(colour_file, sep='\t')
41
                plt.figure(figsize=(8, 5.5))
42
                for index in range(len(colour_setting)):
43
                    code = colour_setting.iloc[index, 1]
44
                    colour = colour_setting.iloc[index, 0]
45
                    if code in latent_code_label.iloc[:, latent_code_dim].unique():
46
                        latent_code_label_part = latent_code_label[latent_code_label.iloc[:, latent_code_dim] == code]
47
                        plt.scatter(latent_code_label_part.iloc[:, 0], latent_code_label_part.iloc[:, 1], s=2,
48
                                    marker='o', alpha=0.8, c=colour, label=code)
49
                plt.legend(ncol=2, markerscale=4, bbox_to_anchor=(1, 1), loc='upper left', frameon=False)
50
            else:
51
                plt.scatter(latent_code.iloc[:, 0], latent_code.iloc[:, 1], s=2, marker='o', alpha=0.8)
52
            plt.xlabel('First Latent Dimension')
53
            plt.ylabel('Second Latent Dimension')
54
        input_file_name = output_path.split('/')[-1]
55
        fig_path = 'results/' + input_file_name + str(latent_code_dim) + 'D_fig.png'
56
        fig_path_svg = 'results/' + input_file_name + str(latent_code_dim) + 'D_fig.svg'
57
        plt.tight_layout()
58
        plt.savefig(fig_path, dpi=300)
59
        plt.savefig(fig_path_svg)