|
a |
|
b/src/move/tasks/analyze_latent.py |
|
|
1 |
# Public API of this module: only the task entry point is exported.
__all__ = ["analyze_latent"]
|
|
2 |
|
|
|
3 |
import re |
|
|
4 |
from pathlib import Path |
|
|
5 |
from typing import Sized, cast |
|
|
6 |
|
|
|
7 |
import hydra |
|
|
8 |
import numpy as np |
|
|
9 |
import pandas as pd |
|
|
10 |
import torch |
|
|
11 |
from sklearn.base import TransformerMixin |
|
|
12 |
|
|
|
13 |
import move.visualization as viz |
|
|
14 |
from move.analysis.metrics import ( |
|
|
15 |
calculate_accuracy, |
|
|
16 |
calculate_cosine_similarity, |
|
|
17 |
) |
|
|
18 |
from move.conf.schema import AnalyzeLatentConfig, MOVEConfig |
|
|
19 |
from move.core.logging import get_logger |
|
|
20 |
from move.core.typing import FloatArray |
|
|
21 |
from move.data import io |
|
|
22 |
from move.data.dataloaders import MOVEDataset, make_dataloader |
|
|
23 |
from move.data.perturbations import ( |
|
|
24 |
perturb_categorical_data, |
|
|
25 |
perturb_continuous_data, |
|
|
26 |
) |
|
|
27 |
from move.data.preprocessing import one_hot_encode_single |
|
|
28 |
from move.models.vae import VAE |
|
|
29 |
from move.training.training_loop import TrainingLoopOutput |
|
|
30 |
|
|
|
31 |
|
|
|
32 |
def find_feature_values(
    feature_name: str,
    feature_names_lists: list[list[str]],
    feature_values: list[FloatArray],
) -> tuple[int, FloatArray]:
    """Locate a feature by name across several datasets and return its values.

    Datasets are searched in order; the first one whose name list contains
    ``feature_name`` is used.

    Args:
        feature_name: Name of the feature to look up.
        feature_names_lists: One list of feature names per dataset.
        feature_values: One data array per dataset, aligned with
            ``feature_names_lists``; the feature is selected along axis 1.

    Raises:
        KeyError: If no dataset contains the feature.

    Returns:
        Tuple ``(dataset_index, values)``: the index of the first dataset that
        contains the feature and that dataset's slice at the feature's
        position along axis 1.
    """
    for dataset_idx, names in enumerate(feature_names_lists):
        try:
            column = names.index(feature_name)
        except ValueError:
            # Not in this dataset; try the next one.
            continue
        return dataset_idx, np.take(feature_values[dataset_idx], column, axis=1)
    raise KeyError(f"Feature '{feature_name}' not in any dataset.")
|
|
64 |
|
|
|
65 |
|
|
|
66 |
def _validate_task_config(task_config: AnalyzeLatentConfig) -> None:
    """Ensure the reducer section of the task config names a target class.

    Raises:
        ValueError: If ``task_config.reducer`` has no ``_target_`` entry.
    """
    reducer_config = task_config.reducer
    has_target = "_target_" in reducer_config
    if not has_target:
        raise ValueError("Reducer class not specified properly.")
|
|
69 |
|
|
|
70 |
|
|
|
71 |
def _invert_mapping(mapping: dict) -> dict:
    """Invert a ``{category: code}`` mapping into ``{str(code): category}``."""
    return {str(code): category for category, code in mapping.items()}


def _latent_shifts(
    model: VAE,
    baseline: FloatArray,
    dataloaders,
    num_samples: int,
) -> FloatArray:
    """Compute the summed latent-space shift caused by each perturbed dataloader.

    Args:
        model: Model used to project the perturbed data.
        baseline: Latent projection of the unperturbed data, one row per sample.
        dataloaders: Sequence of dataloaders, one per perturbed feature.
        num_samples: Number of rows each projection yields.

    Returns:
        Array of shape ``(num_samples, len(dataloaders))`` whose column ``j``
        is ``sum(model.project(dataloaders[j]) - baseline, axis=1)``.
    """
    diffs = np.empty((num_samples, len(dataloaders)))
    for j, dataloader in enumerate(dataloaders):
        z_perturb = model.project(dataloader)
        diffs[:, j] = np.sum(z_perturb - baseline, axis=1)
    return diffs


def analyze_latent(config: MOVEConfig) -> None:
    """Train (or reload) one model to inspect its latent space projections.

    Everything is written to ``<results_path>/latent_space``:

    * ``model.pt``: model weights; reused on later runs if the file exists.
    * ``loss_curve.png`` / ``loss_curve.tsv``: training losses (only written
      when the model is trained in this run).
    * ``latent_space_<feature>.png`` for each requested feature and
      ``latent_space.tsv`` with the first two embedding dimensions plus the
      requested feature values.
    * ``reconstruction_metrics.png`` / ``.tsv``: per-sample accuracy
      (categorical datasets) and cosine similarity (continuous datasets).
    * ``feat_importance_<dataset>.png`` / ``.tsv``: per-sample summed latent
      shift when each feature of a dataset is perturbed.

    Args:
        config: Full MOVE configuration; ``config.task`` is treated as an
            ``AnalyzeLatentConfig``.

    Raises:
        ValueError: If the reducer configuration has no ``_target_``.
    """

    logger = get_logger(__name__)
    logger.info("Beginning task: analyze latent space")
    task_config = cast(AnalyzeLatentConfig, config.task)
    _validate_task_config(task_config)

    raw_data_path = Path(config.data.raw_data_path)
    interim_path = Path(config.data.interim_data_path)
    output_path = Path(config.data.results_path) / "latent_space"
    output_path.mkdir(exist_ok=True, parents=True)

    logger.debug("Reading data")
    sample_names = io.read_names(raw_data_path / f"{config.data.sample_names}.txt")
    cat_list, cat_names, con_list, con_names = io.load_preprocessed_data(
        interim_path,
        config.data.categorical_names,
        config.data.continuous_names,
    )
    test_dataloader = make_dataloader(
        cat_list,
        con_list,
        shuffle=False,
        batch_size=task_config.batch_size,
    )
    test_dataset = cast(MOVEDataset, test_dataloader.dataset)
    df_index = pd.Index(sample_names, name="sample")

    assert task_config.model is not None
    device = torch.device("cuda" if task_config.model.cuda else "cpu")
    model: VAE = hydra.utils.instantiate(
        task_config.model,
        continuous_shapes=test_dataset.con_shapes,
        categorical_shapes=test_dataset.cat_shapes,
    )

    logger.debug(f"Model: {model}")

    model_path = output_path / "model.pt"
    if model_path.exists():
        logger.debug("Re-loading model")
        # map_location remaps the saved tensors onto the current device, so a
        # checkpoint written from a CUDA run can be reloaded on a CPU-only
        # machine (torch.load otherwise restores tensors to their saved device).
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
    else:
        logger.debug("Training model")

        model.to(device)
        train_dataloader = make_dataloader(
            cat_list,
            con_list,
            shuffle=True,
            batch_size=task_config.batch_size,
            drop_last=True,
        )
        output: TrainingLoopOutput = hydra.utils.call(
            task_config.training_loop,
            model=model,
            train_dataloader=train_dataloader,
        )
        # All entries except the last are the loss series plotted below.
        losses = output[:-1]
        torch.save(model.state_dict(), model_path)
        logger.info("Generating visualizations")
        logger.debug("Generating plot: loss curves")
        fig = viz.plot_loss_curves(losses)
        fig_path = str(output_path / "loss_curve.png")
        fig.savefig(fig_path, bbox_inches="tight")
        fig_df = pd.DataFrame(dict(zip(viz.LOSS_LABELS, losses)))
        fig_df.index.name = "epoch"
        fig_df.to_csv(output_path / "loss_curve.tsv", sep="\t")

    model.eval()

    logger.info("Projecting into latent space")
    latent_space = model.project(test_dataloader)
    reducer: TransformerMixin = hydra.utils.instantiate(task_config.reducer)
    embedding = reducer.fit_transform(latent_space)

    mappings_path = interim_path / "mappings.json"
    if mappings_path.exists():
        mappings = io.load_mappings(mappings_path)
    else:
        mappings = {}

    fig_df = pd.DataFrame(
        np.take(embedding, [0, 1], axis=1),
        columns=["dim0", "dim1"],
        index=df_index,
    )

    for feature_name in task_config.feature_names:
        logger.debug(f"Generating plot: latent space + '{feature_name}'")
        is_categorical = False
        try:
            dataset_index, feature_values = find_feature_values(
                feature_name, cat_names, cat_list
            )
            is_categorical = True
        except KeyError:
            try:
                dataset_index, feature_values = find_feature_values(
                    feature_name, con_names, con_list
                )
            except KeyError:
                logger.warning(f"Feature '{feature_name}' not found in any dataset.")
                continue

        if is_categorical:
            # Convert one-hot encoding to category codes; an all-zero row has
            # no active category and is treated as missing.
            is_nan = feature_values.sum(axis=1) == 0
            feature_values = np.argmax(feature_values, axis=1)

            dataset_name = config.data.categorical_names[dataset_index]
            feature_mapping = _invert_mapping(mappings[dataset_name])
            fig = viz.plot_latent_space_with_cat(
                embedding,
                feature_name,
                feature_values,
                feature_mapping,
                is_nan,
            )
            fig_df[feature_name] = np.where(is_nan, np.nan, feature_values)
        else:
            fig = viz.plot_latent_space_with_con(
                embedding, feature_name, feature_values
            )
            # Exact zeros are written to the TSV as NaN.
            fig_df[feature_name] = np.where(feature_values == 0, np.nan, feature_values)

        # Strip every character that is neither a word character nor
        # whitespace so the feature name is safe to use in a file name.
        safe_feature_name = re.sub(r"[^\w\s]", "", feature_name)
        fig_path = str(output_path / f"latent_space_{safe_feature_name}.png")
        fig.savefig(fig_path, bbox_inches="tight")

    fig_df.to_csv(output_path / "latent_space.tsv", sep="\t")

    logger.info("Reconstructing")
    cat_recons, con_recons = model.reconstruct(test_dataloader)
    # Split the concatenated continuous reconstruction back into one array per
    # continuous dataset, using the per-dataset widths.
    con_recons = np.split(con_recons, np.cumsum(model.continuous_shapes[:-1]), axis=1)
    logger.info("Computing reconstruction metrics")
    scores = []
    labels = config.data.categorical_names + config.data.continuous_names
    for cat, cat_recon in zip(cat_list, cat_recons):
        accuracy = calculate_accuracy(cat, cat_recon)
        scores.append(accuracy)
    for con, con_recon in zip(con_list, con_recons):
        cosine_sim = calculate_cosine_similarity(con, con_recon)
        scores.append(cosine_sim)

    logger.debug("Generating plot: reconstruction metrics")

    # Scores exactly equal to 0 are left out of the box plot only; the TSV
    # below keeps every score. NOTE(review): presumably zeros mark samples
    # with no data in that dataset — confirm.
    plot_scores = [np.ma.compressed(np.ma.masked_equal(each, 0)) for each in scores]
    fig = viz.plot_metrics_boxplot(plot_scores, labels)
    fig_path = str(output_path / "reconstruction_metrics.png")
    fig.savefig(fig_path, bbox_inches="tight")
    fig_df = pd.DataFrame(dict(zip(labels, scores)), index=df_index)
    fig_df.to_csv(output_path / "reconstruction_metrics.tsv", sep="\t")

    logger.info("Computing feature importance")
    num_samples = len(cast(Sized, test_dataloader.sampler))
    # The unperturbed baseline is the projection computed above: the model has
    # been in eval mode since then and is not retrained, so re-projecting the
    # same test data for every dataset would only repeat work (assumes
    # model.project is deterministic in eval mode).
    baseline = latent_space

    for i, dataset_name in enumerate(config.data.categorical_names):
        logger.debug(f"Generating plot: feature importance '{dataset_name}'")
        # Encoding of a missing (None) value under this dataset's mapping,
        # passed to perturb_categorical_data as the perturbation target.
        na_value = one_hot_encode_single(mappings[dataset_name], None)
        dataloaders = perturb_categorical_data(
            test_dataloader, config.data.categorical_names, dataset_name, na_value
        )
        diffs = _latent_shifts(model, baseline, dataloaders, num_samples)
        feature_mapping = _invert_mapping(mappings[dataset_name])
        fig = viz.plot_categorical_feature_importance(
            diffs, cat_list[i], cat_names[i], feature_mapping
        )
        fig_path = str(output_path / f"feat_importance_{dataset_name}.png")
        fig.savefig(fig_path, bbox_inches="tight")
        fig_df = pd.DataFrame(diffs, columns=cat_names[i], index=df_index)
        fig_df.to_csv(output_path / f"feat_importance_{dataset_name}.tsv", sep="\t")

    for i, dataset_name in enumerate(config.data.continuous_names):
        logger.debug(f"Generating plot: feature importance '{dataset_name}'")
        # 0.0 is the value passed to perturb_continuous_data as the
        # perturbation target for each continuous feature.
        dataloaders = perturb_continuous_data(
            test_dataloader, config.data.continuous_names, dataset_name, 0.0
        )
        diffs = _latent_shifts(model, baseline, dataloaders, num_samples)
        fig = viz.plot_continuous_feature_importance(diffs, con_list[i], con_names[i])
        fig_path = str(output_path / f"feat_importance_{dataset_name}.png")
        fig.savefig(fig_path, bbox_inches="tight")
        fig_df = pd.DataFrame(diffs, columns=con_names[i], index=df_index)
        fig_df.to_csv(output_path / f"feat_importance_{dataset_name}.tsv", sep="\t")