[40f229]: / model / initialization.py

Download this file

58 lines (49 with data), 1.8 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
# @Author : admin
# @Time : 2018/11/15
import os
from copy import deepcopy
import numpy as np
from .utils import load_data
from .model import Model
def initialize_data(config, train=False, test=False):
    """Build the train/test data sources and optionally preload them.

    Args:
        config: Configuration mapping; ``config['data']`` is forwarded as
            keyword arguments to ``load_data``.
        train: When truthy, call ``load_all_data()`` on the training source.
        test: When truthy, call ``load_all_data()`` on the test source.

    Returns:
        The ``(train_source, test_source)`` pair produced by ``load_data``.
    """
    print("Initializing data source...")
    # Caching is requested whenever either split is going to be preloaded;
    # the value passed is the raw ``train or test`` expression.
    sources = load_data(**config['data'], cache=(train or test))
    train_source, test_source = sources

    # Preload in a fixed order: training split first, then test split.
    preload_plan = (
        (train, train_source, "Loading training data..."),
        (test, test_source, "Loading test data..."),
    )
    for requested, source, message in preload_plan:
        if not requested:
            continue
        print(message)
        source.load_all_data()

    print("Data initialization complete.")
    return train_source, test_source
def initialize_model(config, train_source, test_source):
    """Construct a ``Model`` from the configuration and the data sources.

    The keyword arguments for ``Model`` are a deep copy of
    ``config['model']``. Four entries are added or overridden:
    ``train_source``, ``test_source``, ``train_pid_num`` (taken from
    ``config['data']['pid_num']``) and ``save_name``.

    Args:
        config: Configuration mapping with ``'data'`` and ``'model'`` sections.
        train_source: Training data source handed to ``Model``.
        test_source: Test data source handed to ``Model``.

    Returns:
        A ``(model, save_name)`` tuple. ``save_name`` is the ``'_'``-joined
        string form of selected data/model settings.
    """
    print("Initializing model...")
    data_cfg = config['data']
    model_cfg = config['model']

    # Deep copy so that adding the data sources never mutates the caller's
    # config dictionary.
    params = deepcopy(model_cfg)
    params.update(
        train_source=train_source,
        test_source=test_source,
        train_pid_num=data_cfg['pid_num'],
    )

    # The save name embeds the product of the ``batch_size`` entries rather
    # than the raw ``batch_size`` value.
    effective_batch = int(np.prod(model_cfg['batch_size']))
    name_fields = (
        model_cfg['model_name'],
        data_cfg['dataset'],
        data_cfg['pid_num'],
        data_cfg['pid_shuffle'],
        model_cfg['hidden_dim'],
        model_cfg['margin'],
        effective_batch,
        model_cfg['hard_or_full_trip'],
        model_cfg['frame_num'],
    )
    save_name = '_'.join(str(field) for field in name_fields)
    params['save_name'] = save_name

    model = Model(**params)
    print("Model initialization complete.")
    return model, save_name
def initialization(config, train=False, test=False):
    """Prepare the process environment, then build the data and the model.

    Steps, in order:
      1. ``os.chdir`` into ``config['WORK_PATH']``. Any relative paths used
         by the later steps therefore resolve against that directory.
      2. Export ``config['CUDA_VISIBLE_DEVICES']`` to the environment.
      3. Build the data sources via ``initialize_data``.
      4. Build the model via ``initialize_model``.

    Args:
        config: Configuration mapping. Must contain ``'WORK_PATH'`` and
            ``'CUDA_VISIBLE_DEVICES'``, plus the sections read by
            ``initialize_data`` / ``initialize_model``.
            ``'CUDA_VISIBLE_DEVICES'`` may be a string such as ``"0,1"``, a
            single int, or a list/tuple of device ids.
        train: Forwarded to ``initialize_data``.
        test: Forwarded to ``initialize_data``.

    Returns:
        Whatever ``initialize_model`` returns: a ``(model, save_name)`` tuple.

    Raises:
        TypeError: If ``'CUDA_VISIBLE_DEVICES'`` is neither a str, an int,
            nor a list/tuple.
    """
    print("Initializing...")
    WORK_PATH = config['WORK_PATH']
    os.chdir(WORK_PATH)

    # os.environ only accepts str values, so a config holding an int (``0``)
    # or a list of ids (``[0, 1]``) used to raise TypeError here. Normalize
    # those forms to the comma-separated string the variable expects.
    devices = config["CUDA_VISIBLE_DEVICES"]
    if isinstance(devices, int):
        devices = str(devices)
    elif isinstance(devices, (list, tuple)):
        devices = ','.join(str(device) for device in devices)
    os.environ["CUDA_VISIBLE_DEVICES"] = devices

    train_source, test_source = initialize_data(config, train, test)
    return initialize_model(config, train_source, test_source)