Diff of /test/test_maui.py [000000] .. [433586]

Switch to unified view

a b/test/test_maui.py
1
import os
2
import pytest
3
import tempfile
4
from unittest import mock
5
6
import numpy as np
7
import pandas as pd
8
9
from maui import Maui
10
from maui.maui_warnings import MauiWarning
11
12
13
# Shared module-level fixtures: random feature-by-sample frames that all use
# the same ten sample names as columns.
samples = [f"Sample_{i}" for i in range(10)]


def _random_feature_frame(prefix, n_features):
    """Return an ``n_features`` x ``len(samples)`` frame of standard-normal draws.

    Rows are labelled ``"<prefix>_feature_<i>"`` and columns are ``samples``.
    """
    row_labels = [f"{prefix}_feature_{i}" for i in range(n_features)]
    values = np.random.randn(n_features, len(samples))
    return pd.DataFrame(values, columns=samples, index=row_labels)


n_features_1 = 20
df1 = _random_feature_frame("data1", n_features_1)

n_features_2 = 6
df2 = _random_feature_frame("data2", n_features_2)

# A frame with the shared sample columns but zero feature rows.
df_empty = _random_feature_frame("data0", 0)
33
34
35
def test_validate_X_fails_if_not_dict():
    """``_validate_X`` must raise ``ValueError`` when handed a plain list."""
    model = Maui()
    not_a_dict = [1, 2, 3]
    with pytest.raises(ValueError):
        model._validate_X(not_a_dict)
39
40
41
def test_validate_X_fails_if_samples_mismatch():
    """``_validate_X`` must raise ``ValueError`` when datasets differ in samples."""
    maui_model = Maui()
    # Build the inputs outside the ``raises`` block so that only the validation
    # call itself can satisfy the expected exception.
    df2_bad = df2.iloc[:, :2]  # keep just the first two sample columns
    data_with_mismatching_samples = {"a": df1, "b": df2_bad}
    with pytest.raises(ValueError):
        maui_model._validate_X(data_with_mismatching_samples)
47
48
49
def test_validate_X_fails_if_some_data_empty():
    """``_validate_X`` must raise ``ValueError`` if one dataset has no feature rows."""
    model = Maui()
    data_with_empty_entry = {"a": df1, "e": df_empty}
    with pytest.raises(ValueError):
        model._validate_X(data_with_empty_entry)
53
54
55
def test_validate_X_returns_true_on_valid_data():
    """``_validate_X`` returns a truthy value for two well-formed datasets."""
    model = Maui()
    outcome = model._validate_X({"a": df1, "b": df2})
    assert outcome
59
60
61
def test_dict2array():
    """``_dict2array`` yields one row per sample and one column per input feature."""
    model = Maui()
    expected_rows = len(df1.columns)
    expected_cols = len(df1.index) + len(df2.index)

    stacked = model._dict2array({"data1": df1, "data2": df2})

    assert stacked.shape[0] == expected_rows
    assert stacked.shape[1] == expected_cols
66
67
68
def test_maui_saves_feature_correlations():
    """After fitting, feature correlations are retrievable and stored on the model."""
    maui_model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    # Only the fitting side effect matters here; the embedding is not inspected.
    maui_model.fit_transform({"d1": df1, "d2": df2})
    r = maui_model.get_feature_correlations()
    assert r is not None
    assert hasattr(maui_model, "feature_correlations_")
74
75
76
def test_maui_saves_w():
    """After fitting, linear weights are retrievable and stored as ``w_``."""
    maui_model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    # Only the fitting side effect matters here; the embedding is not inspected.
    maui_model.fit_transform({"d1": df1, "d2": df2})
    w = maui_model.get_linear_weights()
    assert w is not None
    assert hasattr(maui_model, "w_")
82
83
84
def test_maui_saves_neural_weight_product():
    """The neural weight product is stored as ``nwp_`` and matches the encoder weights.

    The ``[0, 0]`` entry is checked against the dot product of the first row of
    the weights of encoder layer 2 and the first column of the weights of
    encoder layer 3.
    """
    maui_model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    # Only the fitting side effect matters here; the embedding is not inspected.
    maui_model.fit_transform({"d1": df1, "d2": df2})
    nwp = maui_model.get_neural_weight_product()
    assert nwp is not None
    assert hasattr(maui_model, "nwp_")

    # Diagnostic output, shown by pytest when the test fails.
    print(maui_model.encoder.summary())

    w1 = maui_model.encoder.layers[2].get_weights()[0]
    w2 = maui_model.encoder.layers[3].get_weights()[0]

    nwp_11 = np.dot(w1[0, :], w2[:, 0])
    assert np.allclose(nwp_11, nwp.iloc[0, 0])
98
99
100
def test_maui_updates_neural_weight_product_when_training():
    """``fine_tune`` must change both the embedding and the neural weight product."""
    inputs = {"d1": df1, "d2": df2}
    model = Maui(n_hidden=[10], n_latent=2, epochs=1)

    # Snapshot taken right after the initial fit.
    embedding_before = model.fit_transform(dict(inputs))
    product_before = model.get_neural_weight_product()

    # Snapshot taken after further training on the same data.
    model.fine_tune(dict(inputs))
    embedding_after = model.transform(dict(inputs))
    product_after = model.get_neural_weight_product()

    assert not np.allclose(embedding_before, embedding_after)
    assert not np.allclose(product_before, product_after)
112
113
114
def test_maui_clusters_with_single_k():
    """Clustering with a fixed ``k`` returns one label per row of ``z_``."""
    model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    sample_labels = [f"sample {i}" for i in range(10)]
    feature_labels = [f"feature {i}" for i in range(20)]

    # Inject a random embedding and data matrix instead of training.
    model.z_ = pd.DataFrame(
        np.random.randn(10, 2), index=sample_labels, columns=["LF1", "LF2"]
    )
    model.x_ = pd.DataFrame(
        np.random.randn(20, 10), index=feature_labels, columns=sample_labels
    )

    assignments = model.cluster(5)
    assert assignments.shape == (10,)
129
130
131
def test_maui_clusters_picks_optimal_k_by_ami():
    """With a patched AMI scorer, ``cluster`` picks the k whose trial scored highest."""
    # Scores returned on successive calls; the best one (3) comes second.
    fake_ami = mock.Mock(side_effect=[2, 3, 1])
    with mock.patch("sklearn.metrics.adjusted_mutual_info_score", fake_ami):
        model = Maui(n_hidden=[10], n_latent=2, epochs=1)
        sample_labels = [f"sample {i}" for i in range(10)]
        feature_labels = [f"feature {i}" for i in range(20)]
        model.z_ = pd.DataFrame(
            np.random.randn(10, 2), index=sample_labels, columns=["LF1", "LF2"]
        )
        model.x_ = pd.DataFrame(
            np.random.randn(20, 10), index=feature_labels, columns=sample_labels
        )

        reference_labels = pd.Series(np.arange(10), index=model.z_.index)

        # k=2 is the second candidate, i.e. the one given the top score.
        model.cluster(ami_y=reference_labels, optimal_k_range=[1, 2, 3])
        print(model.kmeans_scores)
        assert model.optimal_k_ == 2
158
159
160
def test_maui_clusters_picks_optimal_k_by_silhouette():
    """With a patched silhouette scorer, ``cluster`` picks the top-scoring k."""
    # Scores returned on successive calls; the best one (3) comes second.
    fake_silhouette = mock.Mock(side_effect=[2, 3, 1])
    with mock.patch("sklearn.metrics.silhouette_score", fake_silhouette):
        model = Maui(n_hidden=[10], n_latent=2, epochs=1)
        sample_labels = [f"sample {i}" for i in range(10)]
        feature_labels = [f"feature {i}" for i in range(20)]
        model.z_ = pd.DataFrame(
            np.random.randn(10, 2), index=sample_labels, columns=["LF1", "LF2"]
        )
        model.x_ = pd.DataFrame(
            np.random.randn(20, 10), index=feature_labels, columns=sample_labels
        )

        # k=2 is the second candidate, i.e. the one given the top score.
        model.cluster(optimal_k_method="silhouette", optimal_k_range=[1, 2, 3])

        assert model.optimal_k_ == 2
184
185
186
def test_maui_clusters_picks_optimal_k_with_custom_scoring():
    """A callable passed as ``optimal_k_method`` is used to choose the best k."""
    # Scores returned on successive calls; the best one (3) comes second.
    scorer = mock.Mock(side_effect=[2, 3, 1])
    scorer.__name__ = "mock_scorer"

    model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    sample_labels = [f"sample {i}" for i in range(10)]
    feature_labels = [f"feature {i}" for i in range(20)]
    model.z_ = pd.DataFrame(
        np.random.randn(10, 2), index=sample_labels, columns=["LF1", "LF2"]
    )
    model.x_ = pd.DataFrame(
        np.random.randn(20, 10), index=feature_labels, columns=sample_labels
    )

    # k=2 is the second candidate, i.e. the one given the top score.
    model.cluster(optimal_k_method=scorer, optimal_k_range=[1, 2, 3])

    assert model.optimal_k_ == 2
207
208
209
def test_maui_computes_roc_and_auc():
    """``compute_roc`` / ``compute_auc`` return what they store on the model.

    The ROC result must contain an entry per class plus a ``"mean"`` entry.
    """
    model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    latent_rows = [
        [0, 1, 1, 1, 0, 1, 1, 0, 0],
        [1, 0, 0, 0, 0, 0, 1, 1, 0],
        [1, 0, 1, 0, 0, 0, 1, 1, 0],
        [1, 0, 0, 1, 0, 0, 1, 1, 0],
        [1, 0, 0, 0, 1, 1, 1, 1, 0],
        [1, 1, 1, 0, 0, 0, 1, 1, 1],
    ]
    model.z_ = pd.DataFrame(
        latent_rows,
        index=[f"sample {i}" for i in range(6)],
        columns=[f"LF{i}" for i in range(9)],
    )
    classes = pd.Series(["a", "b", "a", "c", "b", "c"], index=model.z_.index)

    roc_result = model.compute_roc(classes, cv_folds=2)
    assert roc_result == model.roc_curves_
    for expected_key in ("a", "b", "c", "mean"):
        assert expected_key in roc_result

    auc_result = model.compute_auc(classes, cv_folds=2)
    assert auc_result == model.aucs_
233
234
235
def test_maui_clusters_only_samples_in_y_index_when_optimizing():
    """When optimising k against labels, only the labelled samples are clustered."""
    np.random.seed(0)
    model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    sample_labels = [f"sample {i}" for i in range(10)]
    feature_labels = [f"feature {i}" for i in range(20)]
    model.z_ = pd.DataFrame(
        np.random.randn(10, 2), index=sample_labels, columns=["LF1", "LF2"]
    )
    model.x_ = pd.DataFrame(
        np.random.randn(20, 10), index=feature_labels, columns=sample_labels
    )

    # Labels exist for the first six of the ten samples only.
    known_classes = pd.Series(
        ["a", "a", "a", "b", "b", "b"], index=[f"sample {i}" for i in range(6)]
    )

    assignments = model.cluster(ami_y=known_classes, optimal_k_range=[1, 2, 3])
    assert set(assignments.index) == set(known_classes.index)
255
256
257
def test_select_clinical_factors():
    """``select_clinical_factors`` keeps ``LF0`` and drops ``LF5`` for this data."""
    model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    sample_labels = [f"sample {i}" for i in range(11)]

    # LF0-LF2 are 1 for the first six samples and 0 for the last five, which
    # matches the short/long duration split below; the other factors do not.
    latent_rows = [
        [1, 1, 1, 0, 0, 0, 1, 0, 1],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 1, 1, 1, 1],
    ]
    model.z_ = pd.DataFrame(
        latent_rows, index=sample_labels, columns=[f"LF{i}" for i in range(9)]
    )

    # Six short durations followed by five long ones; every event is observed.
    durations = [1, 2, 3, 4, 5, 6, 1000, 2000, 3000, 4000, 5000]
    observed = [True] * 11
    survival = pd.DataFrame(
        {"duration": durations, "observed": observed},
        index=[f"sample {i}" for i in range(11)],
    )

    selected = model.select_clinical_factors(survival, cox_penalizer=1, alpha=0.1)
    assert "LF0" in selected.columns
    assert "LF5" not in selected.columns
299
300
301
def test_maui_computes_harrells_c():
    """``c_index`` over three folds is approximately ``[0.5, 0.8, 0.5]`` here."""
    model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    sample_labels = [f"sample {i}" for i in range(11)]

    # LF0-LF2 are 1 for the first six samples and 0 for the last five, which
    # matches the short/long duration split below; the other factors do not.
    latent_rows = [
        [1, 1, 1, 0, 0, 0, 1, 0, 1],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 1, 1, 1, 1],
    ]
    model.z_ = pd.DataFrame(
        latent_rows, index=sample_labels, columns=[f"LF{i}" for i in range(9)]
    )

    # Six short durations followed by five long ones; every event is observed.
    durations = [1, 2, 3, 4, 5, 6, 1000, 2000, 3000, 4000, 5000]
    observed = [True] * 11
    survival = pd.DataFrame(
        {"duration": durations, "observed": observed},
        index=[f"sample {i}" for i in range(11)],
    )

    c_index_options = dict(
        clinical_only=True,
        duration_column="duration",
        observed_column="observed",
        cox_penalties=[0.1, 1, 10, 100, 1000, 10000],
        cv_folds=3,
        sel_clin_alpha=0.1,
        sel_clin_penalty=1,
    )
    scores = model.c_index(survival, **c_index_options)
    print(scores)
    assert np.allclose(scores, [0.5, 0.8, 0.5], atol=0.05)
351
352
353
def test_maui_produces_same_prediction_when_run_twice():
    """Two default transforms of the same data must agree.

    This demonstrates that the default encoder returns the mean of the
    distribution rather than a random sample from it.
    """
    inputs = {"d1": df1, "d2": df2}
    model = Maui(n_hidden=[10], n_latent=2, epochs=1).fit(dict(inputs))
    first_pass = model.transform(dict(inputs))
    second_pass = model.transform(dict(inputs))
    assert np.allclose(first_pass, second_pass)
361
362
363
def test_maui_produces_different_prediction_when_run_twice_with_sampling():
    """Two transforms with ``encoder="sample"`` must not agree.

    Counterpart of the default-encoder test: presumably the ``"sample"``
    encoder draws from the latent distribution instead of returning its mean,
    so repeated calls on the same data are expected to differ.
    """
    maui_model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    maui_model = maui_model.fit({"d1": df1, "d2": df2})
    z1 = maui_model.transform({"d1": df1, "d2": df2}, encoder="sample")
    z2 = maui_model.transform({"d1": df1, "d2": df2}, encoder="sample")
    assert not np.allclose(z1, z2)
371
372
373
def test_maui_produces_nonnegative_zs_if_relu_embedding_true():
    """With ``relu_embedding=True`` every embedding value is >= 0."""
    inputs = {"d1": df1, "d2": df2}
    model = Maui(n_hidden=[10], n_latent=2, epochs=1, relu_embedding=True)
    model = model.fit(dict(inputs))
    embedding = model.transform(dict(inputs))
    is_nonnegative = embedding >= 0
    assert np.all(is_nonnegative)
378
379
380
def test_maui_produces_pos_and_neg_zs_if_relu_embedding_false():
    """With ``relu_embedding=False`` at least one embedding value is negative."""
    inputs = {"d1": df1, "d2": df2}
    model = Maui(n_hidden=[10], n_latent=2, epochs=1, relu_embedding=False)
    model = model.fit(dict(inputs))
    embedding = model.transform(dict(inputs))
    is_nonnegative = embedding >= 0
    assert not np.all(is_nonnegative)
385
386
387
def test_maui_runs_with_deep_not_stacked_vae():
    """Smoke test: fitting with ``architecture="deep"`` completes without error."""
    maui_model = Maui(n_hidden=[10], n_latent=2, epochs=1, architecture="deep")
    # The embedding itself is not checked; the call only has to succeed.
    maui_model.fit_transform({"d1": df1, "d2": df2})
390
391
392
def test_maui_complains_if_wrong_architecture():
    """Constructing ``Maui`` with an unknown ``architecture`` raises ``ValueError``."""
    with pytest.raises(ValueError):
        # The constructor is expected to raise, so its result is never bound.
        Maui(n_hidden=[10], n_latent=2, epochs=1, architecture="wrong value")
397
398
399
def test_maui_supports_single_layer_vae():
    """Smoke test: fit and transform succeed when ``n_hidden`` is ``None``."""
    maui_model = Maui(n_hidden=None, n_latent=2, epochs=1)
    maui_model = maui_model.fit({"d1": df1, "d2": df2})
    # The embedding itself is not checked; the call only has to succeed.
    maui_model.transform({"d1": df1, "d2": df2})
403
404
405
def test_maui_supports_not_deep_deep_vae():
    """Smoke test: ``architecture="deep"`` also works when ``n_hidden`` is ``None``."""
    maui_model = Maui(n_hidden=None, n_latent=2, epochs=1, architecture="deep")
    # The embedding itself is not checked; the call only has to succeed.
    maui_model.fit_transform({"d1": df1, "d2": df2})
408
409
410
def test_maui_drops_unexplanatody_factors_by_r2():
    """``drop_unexplanatory_factors`` keeps 8 of the 9 latent factors here."""
    model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    sample_labels = [f"sample {i}" for i in range(11)]

    # The test expects the first eight factors to pass the R2 threshold and
    # the last one (LF8, all zeros) to be dropped.
    latent_rows = [
        [1, 1, 1, 0, 0, 0, 1, 0, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 1, 1, 1, 0],
    ]
    model.z_ = pd.DataFrame(
        latent_rows,
        index=sample_labels,
        columns=[f"LF{i}" for i in range(9)],
        dtype=float,
    )
    # A single feature: 1 for the first six samples, 0 for the last five.
    feature_values = [[1], [1], [1], [1], [1], [1], [0], [0], [0], [0], [0]]
    model.x_ = pd.DataFrame(
        feature_values,
        index=[f"sample {i}" for i in range(11)],
        columns=["Feature 1"],
        dtype=float,
    )

    kept = model.drop_unexplanatory_factors()

    assert kept.shape[1] == 8
440
441
442
def test_maui_merges_latent_factors():
    """Merging similar factors yields 6 columns, including ``0_1_2`` and ``3_7``."""
    model = Maui(n_hidden=[10], n_latent=2, epochs=1)

    # Factors 0, 1, 2 are expected to merge, as are factors 3 and 7.
    latent_rows = [
        [1, 1, 1, 0, 0, 0, 1, 0, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 1, 1, 1, 0],
    ]
    model.z_ = pd.DataFrame(
        latent_rows,
        index=[f"sample {i}" for i in range(11)],
        columns=[f"LF{i}" for i in range(9)],
        dtype=float,
    )

    merged = model.merge_similar_latent_factors(distance_metric="euclidean")
    assert merged.shape[1] == 6
    for merged_name in ("0_1_2", "3_7"):
        assert merged_name in merged.columns
467
468
469
def test_maui_merges_latent_factors_by_w():
    """Merging with ``distance_in="w"`` yields 4 columns for this data."""
    model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    latent_rows = [
        [1, 1, 1, 0, 0, 0, 1, 0, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 1, 1, 1, 0],
    ]
    model.z_ = pd.DataFrame(
        latent_rows,
        index=[f"sample {i}" for i in range(11)],
        columns=[f"LF{i}" for i in range(9)],
        dtype=float,
    )
    # A single feature: 1 for the first six samples, 0 for the last five.
    feature_values = [[1], [1], [1], [1], [1], [1], [0], [0], [0], [0], [0]]
    model.x_ = pd.DataFrame(
        feature_values,
        index=[f"sample {i}" for i in range(11)],
        columns=["Feature 1"],
        dtype=float,
    )

    # With this z and x the expected groups are 0,1,2 and 4,5 and 3,6,7.
    merged = model.merge_similar_latent_factors(
        distance_in="w", distance_metric="euclidean"
    )
    assert merged.shape[1] == 4
    for merged_name in ("0_1_2", "3_6_7", "4_5"):
        assert merged_name in merged.columns
503
504
505
def test_maui_merge_latent_factors_complains_if_unknown_merge_by():
    """``merge_similar_latent_factors`` raises for an unknown ``distance_in``."""
    maui_model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    # The embedding content is irrelevant here; the call must fail on the
    # unrecognised ``distance_in`` value.
    maui_model.z_ = pd.DataFrame(
        [
            [1, 1, 1, 0, 0, 0, 1, 0, 0],
            [1, 1, 1, 1, 0, 1, 1, 1, 0],
            [1, 1, 1, 1, 0, 1, 1, 1, 0],
            [1, 1, 1, 1, 0, 1, 1, 1, 0],
            [1, 1, 1, 1, 0, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 0, 0, 1, 0],
            [0, 0, 0, 1, 0, 0, 1, 1, 0],
            [0, 0, 0, 1, 0, 0, 1, 1, 0],
            [0, 0, 0, 1, 0, 0, 1, 1, 0],
            [0, 0, 0, 1, 0, 0, 1, 1, 0],
            [0, 0, 0, 1, 0, 1, 1, 1, 0],
        ],
        index=[f"sample {i}" for i in range(11)],
        columns=[f"LF{i}" for i in range(9)],
        dtype=float,
    )

    # NOTE(review): ``Exception`` is very broad; narrow it once the concrete
    # exception type raised by the library is confirmed.
    with pytest.raises(Exception):
        maui_model.merge_similar_latent_factors(
            distance_in="xxx", distance_metric="euclidean"
        )
530
531
532
def test_maui_can_save_to_folder():
    """``save`` writes ``maui_weights.h5`` and ``maui_args.json`` into the folder."""
    model = Maui(n_hidden=[10], n_latent=2, epochs=1).fit({"d1": df1, "d2": df2})
    with tempfile.TemporaryDirectory() as model_dir:
        model.save(model_dir)
        weights_path = os.path.join(model_dir, "maui_weights.h5")
        args_path = os.path.join(model_dir, "maui_args.json")
        assert os.path.isfile(weights_path)
        assert os.path.isfile(args_path)
539
540
541
def test_maui_can_load_from_folder():
    """A saved-and-reloaded model matches the original.

    Compared: ``n_latent``, the first VAE weight array, and the transform of
    the same input data.
    """
    inputs = {"d1": df1, "d2": df2}
    trained = Maui(n_hidden=[10], n_latent=2, epochs=1).fit(dict(inputs))
    with tempfile.TemporaryDirectory() as model_dir:
        trained.save(model_dir)
        restored = Maui.load(model_dir)

    assert restored.n_latent == trained.n_latent

    trained_first_weights = trained.vae.get_weights()[0]
    restored_first_weights = restored.vae.get_weights()[0]
    assert np.allclose(trained_first_weights, restored_first_weights)

    trained_embedding = trained.transform(dict(inputs))
    restored_embedding = restored.transform(dict(inputs))
    assert np.allclose(trained_embedding, restored_embedding)
556
557
558
def test_maui_can_print_verbose_training(capsys):
    """Fitting prints nothing by default and prints "Epoch" with ``verbose=1``."""
    inputs = {"d1": df1, "d2": df2}

    # Default verbosity: stdout must stay empty.
    Maui(n_hidden=[10], n_latent=2, epochs=1).fit(dict(inputs))
    quiet_capture = capsys.readouterr()
    assert quiet_capture.out == ""

    # verbose=1: training progress is written to stdout.
    Maui(n_hidden=[10], n_latent=2, epochs=1, verbose=1).fit(dict(inputs))
    verbose_capture = capsys.readouterr()
    assert "Epoch" in verbose_capture.out
570
571
572
def test_maui_model_makes_2_layer_vae():
    """With one hidden layer the VAE has exactly one hidden encode/decode stage."""
    model = Maui(n_hidden=[10], n_latent=2, epochs=1, input_dim=10)
    layer_names = [layer.name for layer in model.vae.layers]

    expected_present = (
        "hidden_dim_0_mean",
        "latent_mean",
        "decode_hidden_0",
        "reconstruction",
    )
    for expected_name in expected_present:
        assert expected_name in layer_names

    # A second decoder hidden layer must not exist.
    assert "decode_hidden_1" not in layer_names
582
583
584
def test_maui_model_makes_one_layer_vae():
    """With ``n_hidden=[]`` the VAE has no hidden layers and ends in ``reconstruction``."""
    model = Maui(n_hidden=[], n_latent=2, epochs=1, input_dim=10)
    layer_names = [layer.name for layer in model.vae.layers]

    print(layer_names)

    assert layer_names[-1] == "reconstruction"

    has_decode_hidden = any("decode_hidden" in name for name in layer_names)
    assert not has_decode_hidden, "Has a decode hidden..."
    has_hidden_dim = any("hidden_dim" in name for name in layer_names)
    assert not has_hidden_dim, "Has a hidden dim..."
596
597
598
def test_maui_model_validates_feature_names_on_predict_after_fit():
    """``transform`` raises ``ValueError`` when a feature seen during fit is missing."""
    maui_model = Maui(n_hidden=[10], n_latent=2, epochs=1)
    maui_model.fit({"d1": df1, "d2": df2})

    # Sanity check: transforming the training data itself must succeed.
    maui_model.transform({"d1": df1, "d2": df2})

    # Drop the last feature row of the first dataset.
    df1_wrong_features = df1.reindex(df1.index[: len(df1.index) - 1])
    with pytest.raises(ValueError):
        # Same dataset keys as in ``fit`` so that the dropped feature is the
        # only difference; presumably the keys may take part in feature
        # naming, in which case different keys alone would trigger the error.
        maui_model.transform({"d1": df1_wrong_features, "d2": df2})
607
608
609
def test_maui_model_saves_feature_names_to_disk():
    """``feature_names`` are identical after a save/load round trip."""
    trained = Maui(n_hidden=[10], n_latent=2, epochs=1).fit({"d1": df1, "d2": df2})
    with tempfile.TemporaryDirectory() as model_dir:
        trained.save(model_dir)
        restored = Maui.load(model_dir)
    assert trained.feature_names == restored.feature_names
616
617
618
def test_maui_model_loads_model_without_feature_names_from_disk_and_warns():
    """Loading without the feature-names file warns and leaves ``feature_names`` as None."""
    trained = Maui(n_hidden=[10], n_latent=2, epochs=1).fit({"d1": df1, "d2": df2})
    with tempfile.TemporaryDirectory() as model_dir:
        trained.save(model_dir)
        # Remove the saved feature names before loading.
        names_path = os.path.join(model_dir, "maui_feature_names.txt")
        os.remove(names_path)
        with pytest.warns(MauiWarning):
            restored = Maui.load(model_dir)
        assert restored.feature_names is None
627
628
629
def test_maui_can_fine_tune():
    """Smoke test: ``fine_tune`` runs for one epoch on an already fitted model."""
    inputs = {"d1": df1, "d2": df2}
    model = Maui(n_hidden=[], n_latent=2, epochs=1).fit(dict(inputs))
    model.fine_tune(dict(inputs), epochs=1)
633
634
635
def test_maui_complains_if_fine_tune_with_wrong_features():
    """``fine_tune`` raises ``ValueError`` when a feature seen during fit is missing."""
    maui_model = Maui(n_hidden=[], n_latent=2, epochs=1)
    maui_model.fit({"d1": df1, "d2": df2})

    # Drop the last feature row of the first dataset.
    df1_wrong_features = df1.reindex(df1.index[: len(df1.index) - 1])
    with pytest.raises(ValueError):
        # Same dataset keys as in ``fit`` so that the dropped feature is the
        # only difference; presumably the keys may take part in feature
        # naming, in which case different keys alone would trigger the error.
        maui_model.fine_tune({"d1": df1_wrong_features, "d2": df2})