[f8624c]: / ai_genomics / utils / save_plotting.py

Download this file

92 lines (74 with data), 2.6 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
utils.save_plotting
Utils for easier exporting of altair charts
"""
import os
from pathlib import Path
from typing import Iterable, Iterator, Optional

from altair_saver import save
from selenium import webdriver
from webdriver_manager.chrome import ChromeDriverManager

from ai_genomics import PROJECT_DIR
# Root output folder for exported figures; ``create_paths`` adds one
# sub-folder per file type beneath it.
FIGURE_PATH = Path("{}/outputs/figures".format(PROJECT_DIR))
# Formats exported when no explicit list of file types is given.
DEFAULT_FILETYPES = [
    "png",
    "svg",
    "html",
]
def google_chrome_driver_setup():
    """Start a Chrome webdriver used to render altair charts to png/svg.

    ``ChromeDriverManager().install()`` provides the path of a chromedriver
    binary, which is handed to ``selenium.webdriver.Chrome``. The caller owns
    the returned driver and should call ``driver.quit()`` when finished.

    Returns:
        A running ``selenium.webdriver.Chrome`` instance.
    """
    import inspect

    driver_path = ChromeDriverManager().install()
    # Selenium 4 expects the chromedriver location wrapped in a ``Service``;
    # from 4.10 the first positional argument is ``options``, so passing the
    # path positionally fails. Selenium 3 has no ``service`` parameter, so
    # keep the old positional form there.
    if "service" in inspect.signature(webdriver.Chrome).parameters:
        from selenium.webdriver.chrome.service import Service

        return webdriver.Chrome(service=Service(driver_path))
    return webdriver.Chrome(driver_path)
def create_paths(
    path: os.PathLike = FIGURE_PATH, filetypes: Iterable[str] = DEFAULT_FILETYPES
):
    """Create one sub-directory of ``path`` per requested file type.

    Directories that already exist are left untouched (``exist_ok=True``), so
    this can be called before every export.

    Args:
        path: Root directory for the figures; created if missing.
        filetypes: File extensions such as ``["png", "svg", "html"]``. A single
            extension may also be given as a plain string, e.g. ``"png"``.
    """
    # A bare string would otherwise be iterated character by character,
    # creating "p/", "n/", "g/" instead of "png/".
    if isinstance(filetypes, str):
        filetypes = [filetypes]
    for filetype in filetypes:
        os.makedirs(f"{path}/{filetype}", exist_ok=True)
def save_png(fig, path: os.PathLike, name: str, driver):
    """Export ``fig`` as a raster PNG to ``<path>/png/<name>.png``.

    Rendering goes through altair_saver's selenium backend using the given
    webdriver, with ``scale_factor=5`` to enlarge the rendered image.
    """
    png_file = f"{path}/png/{name}.png"
    save(fig, png_file, method="selenium", webdriver=driver, scale_factor=5)
def save_html(fig, path: os.PathLike, name: str):
    """Write ``fig`` as an HTML file to ``<path>/html/<name>.html``.

    Delegates to the chart's own ``save`` method, so no webdriver is needed.
    """
    html_file = f"{path}/html/{name}.html"
    fig.save(html_file)
def save_svg(fig, path: os.PathLike, name: str, driver):
    """Export ``fig`` as a vector SVG to ``<path>/svg/<name>.svg``.

    Rendering goes through altair_saver's selenium backend using the given
    webdriver.
    """
    svg_file = f"{path}/svg/{name}.svg"
    save(fig, svg_file, webdriver=driver, method="selenium")
class AltairSaver:
    """
    Save altair charts to disk in several formats with a single call.

    One Chrome webdriver is started on construction and reused for every
    png/svg export. It stays open until :meth:`close` is called, so either
    call it when done or use the saver as a context manager::

        with AltairSaver() as saver:
            saver.save(chart, "my_chart")

    Args:
        path: Default root directory; each format is written to
            ``<path>/<filetype>/<name>.<filetype>``.
        filetypes: Default formats to export. ``"png"``, ``"svg"`` and
            ``"html"`` are exported; any other value only gets an (empty)
            sub-directory created. A single format may be given as a string.
    """

    def __init__(
        self,
        path: os.PathLike = FIGURE_PATH,
        filetypes: Iterable[str] = DEFAULT_FILETYPES,
    ):
        self.path = path
        # Copy so edits to ``self.filetypes`` cannot mutate the shared
        # module-level DEFAULT_FILETYPES list (and vice versa).
        self.filetypes = self._as_filetype_list(filetypes)
        # Started last so a bad argument above does not leave a driver running.
        self.driver = google_chrome_driver_setup()

    @staticmethod
    def _as_filetype_list(filetypes) -> list:
        """Return ``filetypes`` as a fresh list, wrapping a lone string."""
        if isinstance(filetypes, str):
            return [filetypes]
        return list(filetypes)

    def _get_driver(self):
        """Return the webdriver, restarting it if :meth:`close` was called."""
        if self.driver is None:
            self.driver = google_chrome_driver_setup()
        return self.driver

    def save(
        self,
        fig,
        name: str,
        path: Optional[os.PathLike] = None,
        filetypes: Optional[Iterable[str]] = None,
    ):
        """
        Save an altair figure in each requested format (png, html and/or svg).

        Args:
            fig: altair chart to export.
            name: file name (without extension) used for every exported file.
            path: root directory to save into; defaults to ``self.path``.
            filetypes: formats to export, e.g. ``['png', 'svg', 'html']`` or
                ``'png'``; defaults to ``self.filetypes``. Any iterable is
                accepted, including a one-shot generator.
        """
        path = self.path if path is None else path
        # Materialise once: ``create_paths`` iterates ``filetypes`` and the
        # membership tests below scan it again, which would silently find
        # nothing if a generator had been passed.
        filetypes = self._as_filetype_list(
            self.filetypes if filetypes is None else filetypes
        )
        # Make sure every <path>/<filetype> directory exists.
        create_paths(path, filetypes)
        # Export figures
        if "png" in filetypes:
            save_png(fig, path, name, self._get_driver())
        if "html" in filetypes:
            save_html(fig, path, name)
        if "svg" in filetypes:
            save_svg(fig, path, name, self._get_driver())

    def close(self) -> None:
        """Quit the webdriver. Safe to call more than once.

        A later png/svg export starts a fresh driver.
        """
        if self.driver is not None:
            self.driver.quit()
            self.driver = None

    def __enter__(self) -> "AltairSaver":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the driver; exceptions are not suppressed.
        self.close()