--- a +++ b/tests/test_data.py @@ -0,0 +1,81 @@ +""" +Test suite. +""" + +import unittest +import itertools +import torch + +from continual.src.utils import data_processing + +BATCH_SIZES = (1, 10, 100) +SEQ_LENS = (4, 12, 48) +N_VARS = (2, 10, 100) +N_CLASSES = (2, 10) +N_LAYERS = (1, 2, 3, 4) +HIDDEN_SIZES = (32, 64, 128) + + +DEMOGRAPHICS = [ + "age", + "gender", + "ethnicity", + "region", + "time_year", + "time_season", + "time_month", +] +OUTCOMES = ["ARF", "shock", "mortality"] +DATASETS = ["MIMIC", "eICU"] + + +class TestDataLoadingMethods(unittest.TestCase): + """ + Data loading tests. + """ + + def test_modalfeatvalfromseq(self): + """ + Test that mode of correct dim is returned. + """ + for n_samples in BATCH_SIZES: + for seq_len in SEQ_LENS: + for n_feats in N_VARS: + for i in range(n_feats): + sim_data = ( + torch.randint(0, 1, (n_samples, seq_len, n_feats)) + .clone() + .detach() + .numpy() + ) + modes = data_processing.get_modes(sim_data, feat=i) + self.assertEqual(modes.shape, torch.Size([n_samples])) + + +# CL task split tests +class TestCLConstructionMethods(unittest.TestCase): + """ + Test construction of Continual Learning task splits. + """ + + def ttest_taskidsnonoverlap(self): + for dataset in DATASETS: + for experiment in OUTCOMES: + for demographic in DEMOGRAPHICS: + # JA: implement + tasks = data_processing.load_data(dataset, demographic, experiment) + for pair in itertools.combinations(tasks, repeat=2): + self.assertTrue(pair[0][:, 0].intersection(pair[0][:, 0]) == {}) + + def ttest_tasktargets(self): + for dataset in DATASETS: + for experiment in OUTCOMES: + for demographic in DEMOGRAPHICS: + # JA: implement + tasks = data_processing.load_data(dataset, demographic, experiment) + for task in tasks: + self.assertTrue(len(task[:, -1].unique()) == 2) + + +if __name__ == "__main__": + unittest.main()