--- a +++ b/openomics/visualization/umap.py @@ -0,0 +1,46 @@ +import numpy as np +import plotly.express as px +import umap + + +def d3_umap(X, y_km, heat=None): + """ + Args: + X: + y_km: + heat: + """ + reducer = umap.UMAP(random_state=1234, n_components=3) + X_embedded = reducer.fit_transform(X) + node_colors = get_node_colormap(y_km) + x, y, z = X_embedded[:, 0], X_embedded[:, 1], X_embedded[:, 2] + + fig = px.scatter_3d(x=x, y=y, z=z, color=node_colors) + fig.show() + return reducer + + +def get_node_colormap(node_label): + """ + Args: + node_label: + """ + if type(node_label) == list: + node_labels = node_label + sorted_node_labels = sorted(set(node_labels), reverse=True) + colors = np.linspace(0, 1, len(sorted_node_labels)) + node_colormap = {f: colors[sorted_node_labels.index(f)] for f in set(node_labels)} + node_colors = [node_colormap[n] if n in node_colormap.keys() else None for n in node_labels] + + elif node_label.dtype == "object": + node_labels = node_label.str.split("|", expand=True)[0] + sorted_node_labels = sorted(node_labels.unique(), reverse=True) + colors = np.linspace(0, 1, len(sorted_node_labels)) + node_colormap = {f: colors[sorted_node_labels.index(f)] for f in node_labels.unique()} + node_colors = [node_colormap[n] if n in node_colormap.keys() else None for n in node_labels] + + elif node_label.dtype == "float": + node_labels = node_label.values + node_colormap = None + node_colors = [n / node_labels.max() for n in node_labels] + return node_colors