src/move/data/io.py

__all__ = [
    "dump_names",
    "dump_mappings",
    "load_mappings",
    "load_preprocessed_data",
    "read_config",
    "read_names",
    "read_tsv",
]

import json
from pathlib import Path
from typing import Optional

import hydra
import numpy as np
import pandas as pd
from omegaconf import DictConfig

from move import HYDRA_VERSION_BASE, conf
from move.core.typing import BoolArray, FloatArray, ObjectArray, PathLike

def read_config(
    data_config_name: Optional[str], task_config_name: Optional[str], *args
) -> DictConfig:
    """Composes configuration for the MOVE framework.

    Args:
        data_config_name: Name of data configuration file
        task_config_name: Name of task configuration file
        *args: Additional overrides

    Returns:
        Merged configuration
    """
    overrides = []
    if data_config_name is not None:
        overrides.append(f"data={data_config_name}")
    if task_config_name is not None:
        overrides.append(f"task={task_config_name}")
    overrides.extend(args)
    with hydra.initialize_config_module(conf.__name__, version_base=HYDRA_VERSION_BASE):
        return hydra.compose("main", overrides=overrides)

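# Usage sketch (hypothetical config names; assumes `data/random_small.yaml` and
# `task/encode_data.yaml` exist in the `move.conf` config module):
#
#     config = read_config("random_small", "encode_data", "seed=1")
#
# Extra positional arguments are forwarded verbatim to Hydra, so any
# `key=value` override accepted by `hydra.compose` can be appended.
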
def load_categorical_dataset(filepath: PathLike) -> FloatArray:
    """Loads categorical data from a NumPy file.

    Args:
        filepath: Path to NumPy file containing a categorical dataset

    Returns:
        NumPy array containing categorical data
    """
    return np.load(filepath).astype(np.float32)

def load_continuous_dataset(filepath: PathLike) -> tuple[FloatArray, BoolArray]:
    """Loads continuous data from a NumPy file, encodes NaN values as zeros,
    and then filters out columns (features) whose absolute sum is zero.

    Args:
        filepath: Path to NumPy file containing a continuous dataset

    Returns:
        Tuple containing (1) the NumPy dataset and (2) a mask marking columns
        (i.e., features) that were not filtered out
    """
    data = np.load(filepath).astype(np.float32)
    data[np.isnan(data)] = 0
    mask_col = np.abs(data).sum(axis=0) != 0
    data = data[:, mask_col]
    return data, mask_col

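# Worked example (hypothetical toy file): an all-NaN column becomes all-zero
# after imputation and is therefore dropped; the returned mask marks which
# columns of the original file survived:
#
#     >>> arr = np.array([[1.0, np.nan], [2.0, np.nan]])
#     >>> np.save("toy.npy", arr)
#     >>> data, mask = load_continuous_dataset("toy.npy")
#     >>> data
#     array([[1.],
#            [2.]], dtype=float32)
#     >>> mask
#     array([ True, False])
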
def load_preprocessed_data(
    path: Path,
    categorical_dataset_names: list[str],
    continuous_dataset_names: list[str],
) -> tuple[list[FloatArray], list[list[str]], list[FloatArray], list[list[str]]]:
    """Loads the pre-processed categorical and continuous data.

    Args:
        path: Where the data is saved
        categorical_dataset_names: List of names of the categorical datasets
        continuous_dataset_names: List of names of the continuous datasets

    Returns:
        Tuple containing (1) the categorical data, (2) the feature names of
        each categorical dataset, (3) the continuous data, and (4) the feature
        names of each continuous dataset, filtered to match the kept columns
    """
    categorical_data, categorical_var_names = [], []
    for dataset_name in categorical_dataset_names:
        data = load_categorical_dataset(path / f"{dataset_name}.npy")
        categorical_data.append(data)
        var_names = read_names(path / f"{dataset_name}.txt")
        categorical_var_names.append(var_names)

    continuous_data, continuous_var_names = [], []
    for dataset_name in continuous_dataset_names:
        data, keep = load_continuous_dataset(path / f"{dataset_name}.npy")
        continuous_data.append(data)
        var_names = read_names(path / f"{dataset_name}.txt")
        var_names = [name for i, name in enumerate(var_names) if keep[i]]
        continuous_var_names.append(var_names)

    return (
        categorical_data,
        categorical_var_names,
        continuous_data,
        continuous_var_names,
    )

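# Usage sketch (hypothetical dataset names; assumes the directory holds a
# matching `<name>.npy`/`<name>.txt` pair per dataset, as produced upstream):
#
#     cat_data, cat_names, con_data, con_names = load_preprocessed_data(
#         Path("interim_data"), ["diagnosis"], ["proteomics"]
#     )
#
# Continuous feature names are filtered with the same mask as the data, so
# each `con_data[i]` has `len(con_names[i])` columns, provided the `.txt`
# file lists one name per original column.
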
def read_names(path: PathLike) -> list[str]:
    """Reads names from a text file. The text file should have one name per
    line (e.g., sample names or feature names).

    Args:
        path: Path to the text file

    Returns:
        A list of names
    """
    with open(path, "r", encoding="utf-8") as file:
        return [line.strip() for line in file]

def read_tsv(
    path: PathLike, sample_names: Optional[list[str]] = None
) -> tuple[ObjectArray, np.ndarray]:
    """Reads a dataset from a TSV file. The TSV is expected to have an index
    column (0th index).

    Args:
        path: Path to TSV
        sample_names: List of sample names used to sort/filter samples

    Returns:
        Tuple containing (1) feature names and (2) 2D matrix (samples x
        features)
    """
    data = pd.read_csv(path, index_col=0, sep="\t")
    if sample_names is not None:
        data.index = data.index.astype(str, copy=False)
        data = data.loc[sample_names]
    return data.columns.values, data.values

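# Input sketch (hypothetical file `expression.tsv`; the first column is the
# sample index):
#
#     sample_id    gene_a    gene_b
#     S1           0.5       1.2
#     S2           0.1       0.9
#
#     names, values = read_tsv("expression.tsv", sample_names=["S2", "S1"])
#
# Passing `sample_names` both filters the rows and reorders them, so `values`
# here lists S2 before S1.
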
def load_mappings(path: PathLike) -> dict[str, dict[str, int]]:
    """Loads mappings of category names to integer codes from a JSON file."""
    with open(path, "r", encoding="utf-8") as file:
        return json.load(file)

def dump_mappings(path: PathLike, mappings: dict[str, dict[str, int]]) -> None:
    """Writes mappings of category names to integer codes to a JSON file."""
    with open(path, "w", encoding="utf-8") as file:
        json.dump(mappings, file, indent=4, ensure_ascii=False)

def dump_names(path: PathLike, names: np.ndarray) -> None:
    """Writes names to a text file, one name per line."""
    with open(path, "w", encoding="utf-8") as file:
        file.writelines([f"{name}\n" for name in names])
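
# Round-trip sketch (hypothetical mapping): `dump_mappings`/`load_mappings`
# are JSON inverses of each other, and `dump_names`/`read_names` likewise
# round-trip a list of names (modulo surrounding whitespace):
#
#     dump_mappings("mappings.json", {"diagnosis": {"healthy": 0, "sick": 1}})
#     load_mappings("mappings.json")["diagnosis"]["sick"]  # -> 1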