|
a |
|
b/src/move/visualization/style.py |
|
|
1 |
# Names re-exported by ``from ... import *``.
__all__ = [
    # Default colormap / style names.
    "DEFAULT_DIVERGING_PALETTE",
    "DEFAULT_QUALITATIVE_PALETTE",
    "DEFAULT_PLOT_STYLE",
    # Context-manager factories.
    "color_cycle",
    "style_settings",
]
|
|
8 |
|
|
|
9 |
from typing import ContextManager, cast |
|
|
10 |
|
|
|
11 |
import matplotlib |
|
|
12 |
import matplotlib.style |
|
|
13 |
from cycler import cycler |
|
|
14 |
from matplotlib.cm import ColormapRegistry |
|
|
15 |
from matplotlib.colors import ListedColormap |
|
|
16 |
|
|
|
17 |
# Default Matplotlib colormap and style names exposed by this module.
DEFAULT_DIVERGING_PALETTE: str = "RdYlBu"  # diverging colormap name
DEFAULT_QUALITATIVE_PALETTE: str = "Dark2"  # qualitative colormap name
DEFAULT_PLOT_STYLE: str = "ggplot"  # Matplotlib style-sheet name
|
|
20 |
|
|
|
21 |
|
|
|
22 |
def color_cycle(colormap: str) -> ContextManager:
    """Build a context manager that temporarily replaces the axes color cycle.

    While the returned context is active, the ``axes.prop_cycle`` rcParam is
    set to cycle through the colors of the named colormap.

    Args:
        colormap: Name of a registered colormap. It must be a
            ``ListedColormap`` (a discrete list of colors), such as a
            qualitative palette.

    Returns:
        Context manager that applies the color cycle while active.

    Raises:
        ValueError: If the named colormap is not a ``ListedColormap``.
            Errors from looking up an unknown name in the colormap
            registry propagate unchanged.
    """
    colormaps: ColormapRegistry = matplotlib.colormaps
    cmap = colormaps[colormap]

    # Only discrete (listed) colormaps expose a finite ``colors`` sequence
    # that can be turned into a property cycle.
    if not isinstance(cmap, ListedColormap):
        raise ValueError("Only colormaps that are list of colors supported.")

    color_prop_cycle = cycler(color=cmap.colors)
    return matplotlib.rc_context({"axes.prop_cycle": color_prop_cycle})
|
|
37 |
|
|
|
38 |
|
|
|
39 |
def style_settings(style: str) -> ContextManager:
    """Build a context manager that temporarily applies a Matplotlib style.

    Args:
        style: Style specification forwarded to ``matplotlib.style.context``
            (for example a style-sheet name).

    Returns:
        Context manager that applies the style while active.
    """
    style_context = matplotlib.style.context(style)
    # ``cast`` only informs the type checker; it does nothing at runtime.
    return cast(ContextManager, style_context)