Diff of /model/initialization.py [000000] .. [40f229]

Switch to unified view

a b/model/initialization.py
1
# -*- coding: utf-8 -*-
2
# @Author  : admin
3
# @Time    : 2018/11/15
4
import os
5
from copy import deepcopy
6
7
import numpy as np
8
9
from .utils import load_data
10
from .model import Model
11
12
13
def initialize_data(config, train=False, test=False):
14
    print("Initializing data source...")
15
    train_source, test_source = load_data(**config['data'], cache=(train or test))
16
    if train:
17
        print("Loading training data...")
18
        train_source.load_all_data()
19
    if test:
20
        print("Loading test data...")
21
        test_source.load_all_data()
22
    print("Data initialization complete.")
23
    return train_source, test_source
24
25
26
def initialize_model(config, train_source, test_source):
27
    print("Initializing model...")
28
    data_config = config['data']
29
    model_config = config['model']
30
    model_param = deepcopy(model_config)
31
    model_param['train_source'] = train_source
32
    model_param['test_source'] = test_source
33
    model_param['train_pid_num'] = data_config['pid_num']
34
    batch_size = int(np.prod(model_config['batch_size']))
35
    model_param['save_name'] = '_'.join(map(str,[
36
        model_config['model_name'],
37
        data_config['dataset'],
38
        data_config['pid_num'],
39
        data_config['pid_shuffle'],
40
        model_config['hidden_dim'],
41
        model_config['margin'],
42
        batch_size,
43
        model_config['hard_or_full_trip'],
44
        model_config['frame_num'],
45
    ]))
46
47
    m = Model(**model_param)
48
    print("Model initialization complete.")
49
    return m, model_param['save_name']
50
51
52
def initialization(config, train=False, test=False):
53
    print("Initialzing...")
54
    WORK_PATH = config['WORK_PATH']
55
    os.chdir(WORK_PATH)
56
    os.environ["CUDA_VISIBLE_DEVICES"] = config["CUDA_VISIBLE_DEVICES"]
57
    train_source, test_source = initialize_data(config, train, test)
58
    return initialize_model(config, train_source, test_source)