--- a +++ b/tests/io/test_read.py @@ -0,0 +1,291 @@ +import numpy as np +import pandas as pd +import pytest +from pandas import CategoricalDtype + +from ehrapy.io._read import read_csv, read_fhir, read_h5ad +from tests.conftest import TEST_DATA_PATH + +_TEST_PATH = f"{TEST_DATA_PATH}/io" +_TEST_PATH_H5AD = f"{_TEST_PATH}/h5ad" +_TEST_PATH_FHIR = f"{_TEST_PATH}/fhir/json" + + +def test_read_csv(): + adata = read_csv(dataset_path=f"{_TEST_PATH}/dataset_basic.csv") + matrix = np.array( + [[12, 14, 500, False], [13, 7, 330, False], [14, 10, 800, True], [15, 11, 765, True], [16, 3, 800, True]] + ) + assert adata.X.shape == (5, 4) + assert (adata.X == matrix).all() + assert adata.var_names.to_list() == ["patient_id", "los_days", "b12_values", "survival"] + assert (adata.layers["original"] == matrix).all() + assert id(adata.layers["original"]) != id(adata.X) + + +def test_read_tsv(): + adata = read_csv(dataset_path=f"{_TEST_PATH}/dataset_tsv.tsv", sep="\t") + matrix = np.array( + [ + [12, 54, 185.34, False], + [13, 25, 175.39, True], + [14, 36, 183.29, False], + [15, 44, 173.93, True], + [16, 27, 190.32, True], + ] + ) + assert adata.X.shape == (5, 4) + assert (adata.X == matrix).all() + assert adata.var_names.to_list() == ["patient_id", "age", "height", "gamer"] + assert (adata.layers["original"] == matrix).all() + assert id(adata.layers["original"]) != id(adata.X) + + +def test_read_multiple_csv_to_anndatas(): + adatas = read_csv(dataset_path=f"{_TEST_PATH}") + adata_ids = set(adatas.keys()) + assert all(adata_id in adata_ids for adata_id in {"dataset_non_num_with_missing", "dataset_num_with_missing"}) + assert set(adatas["dataset_non_num_with_missing"].var_names) == { + "indexcol", + "intcol", + "strcol", + "boolcol", + "binary_col", + } + assert set(adatas["dataset_num_with_missing"].var_names) == {"col" + str(i) for i in range(1, 4)} + + +def test_read_multiple_csvs_to_dfs(): + dfs = read_csv(dataset_path=f"{_TEST_PATH}", return_dfs=True) + dfs_ids = set(dfs.keys()) + assert all(id in dfs_ids for id in {"dataset_non_num_with_missing", "dataset_num_with_missing"}) + assert set(dfs["dataset_non_num_with_missing"].columns) == { + "indexcol", + "intcol", + "strcol", + "boolcol", + "binary_col", + "datetime", + } + + +def test_read_multiple_csv_with_obs_only(): + adatas = read_csv( + dataset_path=f"{_TEST_PATH}", + columns_obs_only={"dataset_non_num_with_missing": ["strcol"], "dataset_num_with_missing": ["col1"]}, + ) + adata_ids = set(adatas.keys()) + assert all(adata_id in adata_ids for adata_id in {"dataset_non_num_with_missing", "dataset_num_with_missing"}) + assert set(adatas["dataset_non_num_with_missing"].var_names) == {"indexcol", "intcol", "boolcol", "binary_col"} + assert set(adatas["dataset_num_with_missing"].var_names) == {"col" + str(i) for i in range(2, 4)} + assert all( + obs_name in set(adatas["dataset_non_num_with_missing"].obs.columns) for obs_name in {"datetime", "strcol"} + ) + assert "col1" in set(adatas["dataset_num_with_missing"].obs.columns) + + +def test_read_h5ad(): + adata = read_h5ad(dataset_path=f"{_TEST_PATH_H5AD}/dataset9.h5ad") + + assert adata.X.shape == (4, 3) + assert set(adata.var_names) == {"col" + str(i) for i in range(1, 4)} + assert set(adata.obs.columns) == set() + + +def test_read_multiple_h5ad(): + adatas = read_h5ad(dataset_path=f"{_TEST_PATH_H5AD}") + adata_ids = set(adatas.keys()) + + assert all(adata_id in adata_ids for adata_id in {"dataset8", "dataset9"}) + assert set(adatas["dataset8"].var_names) == {"indexcol", "intcol", "boolcol", "binary_col", "strcol"} + assert set(adatas["dataset9"].var_names) == {"col" + str(i) for i in range(1, 4)} + assert all(obs_name in set(adatas["dataset8"].obs.columns) for obs_name in {"datetime"}) + + +def test_read_csv_without_index_column(): + adata = read_csv(dataset_path=f"{_TEST_PATH}/dataset_index.csv") + matrix = np.array( + [[1, 14, 500, False], [2, 7, 330, False], [3, 10, 800, True], [4, 11, 765, True], [5, 3, 800, True]] + ) + assert adata.X.shape == (5, 4) + assert (adata.X == matrix).all() + assert adata.var_names.to_list() == ["clinic_id", "los_days", "b12_values", "survival"] + assert (adata.layers["original"] == matrix).all() + assert id(adata.layers["original"]) != id(adata.X) + assert list(adata.obs.index) == ["0", "1", "2", "3", "4"] + + +def test_read_csv_with_bools_obs_only(): + adata = read_csv(dataset_path=f"{_TEST_PATH}/dataset_basic.csv", columns_obs_only=["survival", "b12_values"]) + matrix = np.array([[12, 14], [13, 7], [14, 10], [15, 11], [16, 3]]) + assert adata.X.shape == (5, 2) + assert (adata.X == matrix).all() + assert adata.var_names.to_list() == ["patient_id", "los_days"] + assert (adata.layers["original"] == matrix).all() + assert id(adata.layers["original"]) != id(adata.X) + assert set(adata.obs.columns) == {"b12_values", "survival"} + assert pd.api.types.is_bool_dtype(adata.obs["survival"].dtype) + assert pd.api.types.is_numeric_dtype(adata.obs["b12_values"].dtype) + + +def test_read_csv_with_bools_and_cats_obs_only(): + adata = read_csv( + dataset_path=f"{_TEST_PATH}/dataset_bools_and_str.csv", columns_obs_only=["b12_values", "name", "survival"] + ) + matrix = np.array([[1, 14], [2, 7], [3, 10], [4, 11], [5, 3]]) + assert adata.X.shape == (5, 2) + assert (adata.X == matrix).all() + assert adata.var_names.to_list() == ["clinic_id", "los_days"] + assert (adata.layers["original"] == matrix).all() + assert id(adata.layers["original"]) != id(adata.X) + assert set(adata.obs.columns) == {"b12_values", "survival", "name"} + assert pd.api.types.is_bool_dtype(adata.obs["survival"].dtype) + assert pd.api.types.is_numeric_dtype(adata.obs["b12_values"].dtype) + assert isinstance(adata.obs["name"].dtype, CategoricalDtype) + + +def test_set_default_index(): + adata = read_csv(dataset_path=f"{_TEST_PATH}/dataset_index.csv") + assert adata.X.shape == (5, 4) + assert not adata.obs_names.name + assert adata.obs.index.values.tolist() == [f"{i}" for i in range(5)] + + +def test_set_given_str_index(): + adata = read_csv(dataset_path=f"{_TEST_PATH}/dataset_basic.csv", index_column="los_days") + assert adata.X.shape == (5, 3) + assert adata.obs_names.name == "los_days" + assert adata.obs.index.values.tolist() == ["14", "7", "10", "11", "3"] + + +def test_set_given_int_index(): + adata = read_csv(dataset_path=f"{_TEST_PATH}/dataset_basic.csv", index_column=1) + assert adata.X.shape == (5, 3) + assert adata.obs_names.name == "los_days" + assert adata.obs.index.values.tolist() == ["14", "7", "10", "11", "3"] + + +def test_move_single_column_misspelled(): + with pytest.raises(ValueError): + _ = read_csv(dataset_path=f"{_TEST_PATH}/dataset_basic.csv", columns_obs_only=["b11_values"]) + + +def test_move_single_column_to_obs(): + adata = read_csv(dataset_path=f"{_TEST_PATH}/dataset_basic.csv", columns_obs_only=["b12_values"]) + assert adata.X.shape == (5, 3) + assert list(adata.obs.columns) == ["b12_values"] + assert "b12_values" not in list(adata.var_names.values) + + +def test_move_multiple_columns_to_obs(): + adata = read_csv(dataset_path=f"{_TEST_PATH}/dataset_basic.csv", columns_obs_only=["b12_values", "survival"]) + assert adata.X.shape == (5, 2) + assert list(adata.obs.columns) == ["b12_values", "survival"] + assert "b12_values" not in list(adata.var_names.values) and "survival" not in list(adata.var_names.values) + + +def test_read_raises_error_with_duplicates_columns_only_single_1(): + with pytest.raises(ValueError): + _ = read_csv( + dataset_path=f"{_TEST_PATH}/dataset_basic.csv", + columns_obs_only=["survival", "b12_values"], + columns_x_only=["survival", "b12_values"], + ) + + +def test_read_raises_error_with_duplicates_columns_only_single_2(): + with pytest.raises(ValueError): + _ = read_csv( + dataset_path=f"{_TEST_PATH}/dataset_basic.csv", + columns_obs_only=["survival"], + columns_x_only=["survival", "b12_values"], + ) + + +def test_read_raises_error_with_duplicates_columns_only_multiple_1(): + with pytest.raises(ValueError): + _ = read_csv( + dataset_path=f"{_TEST_PATH}", + columns_obs_only={ + "dataset_non_num_with_missing": ["intcol"], + "dataset_num_with_missing": ["col1", "col2"], + }, + columns_x_only={"dataset_non_num_with_missing": ["intcol"]}, + ) + + +def test_read_raises_error_with_duplicates_columns_only_multiple_2(): + with pytest.raises(ValueError): + _ = read_csv( + dataset_path=f"{_TEST_PATH}", + columns_obs_only={ + "dataset_non_num_with_missing": ["intcol"], + "dataset_num_with_missing": ["col1", "col2"], + }, + columns_x_only={"dataset_non_num_with_missing": ["indexcol"], "dataset_num_with_missing": ["col3"]}, + ) + + +def test_move_single_column_to_x(): + adata = read_csv(dataset_path=f"{_TEST_PATH}/dataset_basic.csv", columns_x_only=["b12_values"]) + assert adata.X.shape == (5, 1) + assert list(adata.var_names) == ["b12_values"] + assert "b12_values" not in list(adata.obs.columns) + assert all(obs_names in list(adata.obs.columns) for obs_names in ["los_days", "patient_id", "survival"]) + + +def test_move_multiple_columns_to_x(): + adata = read_csv(dataset_path=f"{_TEST_PATH}/dataset_basic.csv", columns_x_only=["b12_values", "survival"]) + assert adata.X.shape == (5, 2) + assert all(var_names in list(adata.var_names) for var_names in ["b12_values", "survival"]) + assert all(obs_names in list(adata.obs.columns) for obs_names in ["los_days", "patient_id"]) + assert all(var_names not in list(adata.obs.columns) for var_names in ["b12_values", "survival"]) + + +def test_read_multiple_csv_with_x_only(): + adatas = read_csv( + dataset_path=f"{_TEST_PATH}", + columns_x_only={"dataset_non_num_with_missing": ["strcol"], "dataset_num_with_missing": ["col1"]}, + ) + adata_ids = set(adatas.keys()) + assert all(adata_id in adata_ids for adata_id in {"dataset_non_num_with_missing", "dataset_num_with_missing"}) + assert set(adatas["dataset_non_num_with_missing"].obs.columns) == { + "indexcol", + "intcol", + "boolcol", + "binary_col", + "datetime", + } + assert set(adatas["dataset_num_with_missing"].obs.columns) == {"col" + str(i) for i in range(2, 4)} + assert set(adatas["dataset_non_num_with_missing"].var_names) == {"strcol"} + assert set(adatas["dataset_num_with_missing"].var_names) == {"col1"} + + +def test_read_multiple_csv_with_x_only_2(): + adatas = read_csv( + dataset_path=f"{_TEST_PATH}", + columns_x_only={ + "dataset_non_num_with_missing": ["strcol", "intcol", "boolcol"], + "dataset_num_with_missing": ["col1", "col3"], + }, + ) + adata_ids = set(adatas.keys()) + assert all(adata_id in adata_ids for adata_id in {"dataset_non_num_with_missing", "dataset_num_with_missing"}) + assert set(adatas["dataset_non_num_with_missing"].obs.columns) == {"indexcol", "binary_col", "datetime"} + assert set(adatas["dataset_num_with_missing"].obs.columns) == {"col2"} + assert set(adatas["dataset_non_num_with_missing"].var_names) == {"strcol", "intcol", "boolcol"} + assert set(adatas["dataset_num_with_missing"].var_names) == {"col1", "col3"} + + +def test_read_fhir_json(): + adata = read_fhir(_TEST_PATH_FHIR) + + assert adata.shape == (4928, 80) + assert "resource.birthDate" in adata.obs.columns + + +def test_read_fhir_json_obs_only(): + adata = read_fhir(_TEST_PATH_FHIR, columns_obs_only=["fullUrl"]) + + assert adata.shape == (4928, 79) + assert "fullUrl" in adata.obs.columns