"""Unit tests for the dataset helpers imported from ``src.data``."""

import unittest

import pandas as pd

from src.data import load_chia, load_fb, train_test_val_split, get_chia_annotations


class DataTestCase(unittest.TestCase):
    """Checks on the loaders, the splitter and the annotation helper."""

    def setUp(self):
        """Load the Chia dataframe and the FB splits before each test."""
        # Loaded per test (not per class) so every test gets fresh objects;
        # it is not visible from here whether the helpers mutate their input.
        self.df_chia = load_chia()
        self.df_fb = load_fb()

    def test_load_chia_returns_dataframe(self):
        """Tests if load_chia returns a Pandas dataframe"""
        self.assertIsInstance(self.df_chia, pd.DataFrame)

    def test_chia_number_of_rows(self):
        """Tests if load_chia returns a dataframe with 2000 rows"""
        self.assertEqual(self.df_chia.shape[0], 2000)

    def test_chia_number_of_columns(self):
        """Tests if load_chia returns a dataframe with 12 columns"""
        self.assertEqual(self.df_chia.shape[1], 12)

    def test_fb_returns_dict(self):
        """Tests if load_fb returns a dictionary"""
        self.assertIsInstance(self.df_fb, dict)

    def test_fb_keys(self):
        """Tests if load_fb returns a dictionary with exactly the train, test and val keys"""
        self.assertEqual(len(self.df_fb), 3)
        for key in ("train", "test", "val"):
            # subTest reports every missing key instead of stopping at the first.
            with self.subTest(key=key):
                self.assertIn(key, self.df_fb)

    def test_fb_num_rows(self):
        """Tests if the three dataframes returned by load_fb have the expected row counts"""
        # NOTE(review): the expected test split (10116) is far larger than
        # train (1243) -- confirm these counts against the dataset.
        self.assertEqual(self.df_fb["train"].shape[0], 1243)
        self.assertEqual(self.df_fb["test"].shape[0], 10116)
        self.assertEqual(self.df_fb["val"].shape[0], 376)

    def test_train_test_val_split(self):
        """Tests if train_test_val_split returns a dictionary with 3 keys"""
        self.assertEqual(len(train_test_val_split(self.df_chia)), 3)

    def test_train_test_val_split_keys(self):
        """Tests if train_test_val_split returns a dictionary with train, test, and val keys"""
        # Split once and reuse the result rather than re-splitting per assertion.
        splits = train_test_val_split(self.df_chia)
        for key in ("train", "test", "val"):
            with self.subTest(key=key):
                self.assertIn(key, splits)

    def test_train_test_val_split_wrong_ratios(self):
        """Tests if train_test_val_split raises AssertionError when sum of ratios is not 100"""
        # 70 + 20 + 11 == 101, deliberately off by one.
        with self.assertRaises(AssertionError):
            train_test_val_split(self.df_chia, ratio=(70, 20, 11))

    def test_train_test_val_split_sizes_of_splits(self):
        """Tests if a 70/20/10 split of the Chia dataframe yields 1400/400/200 rows"""
        splits = train_test_val_split(self.df_chia, ratio=(70, 20, 10))
        self.assertEqual(splits["train"].shape[0], 1400)
        self.assertEqual(splits["test"].shape[0], 400)
        self.assertEqual(splits["val"].shape[0], 200)

    def test_get_chia_annotations_returns_list_of_tuples(self):
        """Tests if get_chia_annotations returns a list whose first element is a tuple"""
        # Fetch once; both assertions inspect the same returned object.
        annotations = get_chia_annotations("drugs")
        self.assertIsInstance(annotations, list)
        self.assertIsInstance(annotations[0], tuple)

    def test_get_chia_annotations_returns_limited_num_rows(self):
        """Tests if get_chia_annotations returns exactly n items when n is given"""
        self.assertEqual(len(get_chia_annotations("drugs", n=10)), 10)

    def test_get_chia_annotations_raises_error_for_wrong_entity(self):
        """Tests if get_chia_annotations raises AssertionError for an unknown entity"""
        with self.assertRaises(AssertionError):
            get_chia_annotations("wrong_entity")


if __name__ == "__main__":
    unittest.main()