--- a
+++ b/plot_sactter.py
@@ -0,0 +1,93 @@
+import os
+
+import pandas as pd
+import matplotlib.pyplot as plt
+from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib
+
+
+def _scatter_by_label(ax, latent_code, label_file, colour_file, latent_code_dim):
+    """Draw one scatter series per label code listed in ``colour_file``.
+
+    ``latent_code`` is inner-joined on its index with the table in ``label_file``,
+    so samples without a label are dropped; samples whose label has no row in
+    ``colour_file`` are not drawn either.
+    """
+    disease_id = pd.read_csv(label_file, sep='\t', index_col=0)
+    latent_code_label = pd.merge(latent_code, disease_id, left_index=True, right_index=True)
+    # merge keeps latent_code's columns first, so the first column from the label
+    # file sits right after them. (Indexing by latent_code_dim would pick a latent
+    # column whenever latent_code carries extra columns.)
+    labels = latent_code_label.iloc[:, latent_code.shape[1]]
+    present_codes = labels.unique()  # computed once instead of once per colour row
+    colour_setting = pd.read_csv(colour_file, sep='\t')
+    # The colour file is read by position: column 0 is the colour, column 1 the label code.
+    for colour, code in zip(colour_setting.iloc[:, 0], colour_setting.iloc[:, 1]):
+        if code not in present_codes:
+            continue
+        part = latent_code_label[labels == code]
+        coords = [part.iloc[:, axis] for axis in range(latent_code_dim)]
+        ax.scatter(*coords, s=2, marker='o', alpha=0.8, color=colour, label=code)
+
+
+def plot_scatter(latent_code, output_path,
+                 label_file='data/PANCAN/GDC-PANCAN_both_samples_tumour_type.tsv',
+                 colour_file='data/TCGA_colors_obvious.tsv', latent_code_dim=2, have_label=True,
+                 result_dir='results'):
+    """Save a 2-D or 3-D scatter plot of a latent space as PNG (300 dpi) and SVG.
+
+    Args:
+        latent_code: DataFrame whose first ``latent_code_dim`` columns are the
+            coordinates to plot. With ``have_label`` its index is inner-joined
+            with the index (first column) of ``label_file``.
+        output_path: Only its final path component is used, as the prefix of
+            the saved file names.
+        label_file: Tab-separated table indexed by its first column; the next
+            column holds the label code that is matched against ``colour_file``.
+        colour_file: Tab-separated table with a header row; column 0 holds a
+            matplotlib colour and column 1 the label code it is used for.
+            Samples whose label is not listed here are not drawn.
+        latent_code_dim: 2 or 3. Any other value draws and saves nothing.
+        have_label: Colour points by label and add a legend when True; draw a
+            single uncoloured series when False.
+        result_dir: Directory the figures are written to; created if missing.
+
+    Returns:
+        ``(png_path, svg_path)`` of the saved figures, or ``None`` when
+        ``latent_code_dim`` is not 2 or 3.
+    """
+    if latent_code_dim not in (2, 3):
+        return None
+    is_3d = latent_code_dim == 3
+
+    # Always start a fresh figure: the old 2-D unlabelled path drew onto whatever
+    # figure happened to be current.
+    fig = plt.figure(figsize=(8, 5.5)) if have_label else plt.figure()
+    try:
+        ax = fig.add_subplot(111, projection='3d' if is_3d else None)
+        if have_label:
+            _scatter_by_label(ax, latent_code, label_file, colour_file, latent_code_dim)
+            # Legend goes outside the axes on the right; the 3-D one starts a little lower.
+            ax.legend(ncol=2, markerscale=4, bbox_to_anchor=(1, 0.9 if is_3d else 1),
+                      loc='upper left', frameon=False)
+        else:
+            coords = [latent_code.iloc[:, axis] for axis in range(latent_code_dim)]
+            ax.scatter(*coords, s=2, marker='o', alpha=0.8)
+
+        ax.set_xlabel('First Latent Dimension')
+        ax.set_ylabel('Second Latent Dimension')
+        if is_3d:
+            ax.set_zlabel('Third Latent Dimension')
+
+        input_file_name = os.path.basename(output_path)
+        stem = os.path.join(result_dir, input_file_name + str(latent_code_dim) + 'D_fig')
+        fig_path = stem + '.png'
+        fig_path_svg = stem + '.svg'
+        if result_dir:
+            os.makedirs(result_dir, exist_ok=True)
+        fig.tight_layout()
+        fig.savefig(fig_path, dpi=300)
+        fig.savefig(fig_path_svg)
+    finally:
+        # pyplot keeps every figure it creates alive until it is closed.
+        plt.close(fig)
+    return fig_path, fig_path_svg