|
a |
|
b/utils/data_loader.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
|
|
|
3 |
import numpy as np |
|
|
4 |
from utils import format_filename |
|
|
5 |
from config import PROCESSED_DATA_DIR, TRAIN_DATA_TEMPLATE, DEV_DATA_TEMPLATE, TEST_DATA_TEMPLATE |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
def load_data(dataset: str, data_type: str): |
|
|
9 |
if data_type == 'train': |
|
|
10 |
return np.load(format_filename(PROCESSED_DATA_DIR, TRAIN_DATA_TEMPLATE, dataset=dataset)) |
|
|
11 |
elif data_type == 'dev': |
|
|
12 |
return np.load(format_filename(PROCESSED_DATA_DIR, DEV_DATA_TEMPLATE, dataset=dataset)) |
|
|
13 |
elif data_type == 'test': |
|
|
14 |
return np.load(format_filename(PROCESSED_DATA_DIR, TEST_DATA_TEMPLATE, dataset=dataset)) |
|
|
15 |
else: |
|
|
16 |
raise ValueError('`data_type` not understood: {data_type}') |