[21b321]: / utils / data_loader.py

Download this file

17 lines (13 with data), 680 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
# -*- coding: utf-8 -*-
import numpy as np
from utils import format_filename
from config import PROCESSED_DATA_DIR, TRAIN_DATA_TEMPLATE, DEV_DATA_TEMPLATE, TEST_DATA_TEMPLATE
def load_data(dataset: str, data_type: str):
    """Load a processed data split for ``dataset`` with ``np.load``.

    The file path is built by ``format_filename`` from ``PROCESSED_DATA_DIR``
    and the filename template matching ``data_type``.

    Args:
        dataset: Dataset name, forwarded to ``format_filename`` as the
            ``dataset`` keyword.
        data_type: Which split to load; one of ``'train'``, ``'dev'`` or
            ``'test'``.

    Returns:
        Whatever ``np.load`` returns for the resolved file.

    Raises:
        ValueError: If ``data_type`` is not one of the supported splits.
    """
    # Resolve the template first so an unknown split is rejected before any
    # filesystem access happens.
    if data_type == 'train':
        template = TRAIN_DATA_TEMPLATE
    elif data_type == 'dev':
        template = DEV_DATA_TEMPLATE
    elif data_type == 'test':
        template = TEST_DATA_TEMPLATE
    else:
        # The original message lacked the f-prefix, so it printed the literal
        # text "{data_type}" instead of the offending value.
        raise ValueError(f'`data_type` not understood: {data_type}')
    return np.load(format_filename(PROCESSED_DATA_DIR, template, dataset=dataset))