|
a |
|
b/simdeep/simdeep_utils.py |
|
|
1 |
from simdeep.config import PATH_TO_SAVE_MODEL |
|
|
2 |
|
|
|
3 |
from os.path import isfile |
|
|
4 |
from os.path import isdir |
|
|
5 |
|
|
|
6 |
from os import mkdir |
|
|
7 |
|
|
|
8 |
# from sys import version_info |
|
|
9 |
|
|
|
10 |
# if version_info > (3, 0, 0): |
|
|
11 |
# import pickle as cPickle |
|
|
12 |
# else: |
|
|
13 |
# import cPickle |
|
|
14 |
|
|
|
15 |
import dill |
|
|
16 |
|
|
|
17 |
from time import time |
|
|
18 |
|
|
|
19 |
def save_model(boosting, path_to_save_model=PATH_TO_SAVE_MODEL):
    """
    Serialize a boosting instance with dill.

    The instance is written to
    ``<path_to_save_model>/<boosting._project_name>.pickle`` after
    ``boosting._convert_logs()`` has been called. The time spent dumping
    the instance is printed.

    :param boosting: object exposing ``_convert_logs()`` and
        ``_project_name``
    :param path_to_save_model: destination folder; it is created, together
        with any missing parent folder, when it does not exist
    """
    # Local import: only this function needs the recursive variant of mkdir.
    from os import makedirs

    if not isdir(path_to_save_model):
        try:
            makedirs(path_to_save_model)
        except OSError:
            # Another process may have created the folder between the check
            # and the creation; only propagate when it is really missing.
            if not isdir(path_to_save_model):
                raise

    boosting._convert_logs()

    t_start = time()

    path_pickle = '{0}/{1}.pickle'.format(
        path_to_save_model, boosting._project_name)

    with open(path_pickle, 'wb') as f_pick:
        dill.dump(boosting, f_pick)

    print('model saved in %2.1f s at %s/%s.pickle' % (
        time() - t_start, path_to_save_model, boosting._project_name))
|
|
35 |
|
|
|
36 |
|
|
|
37 |
def load_model(project_name, path_model=PATH_TO_SAVE_MODEL):
    """
    Load an instance previously dumped with dill.

    The file read is ``<path_model>/<project_name>.pickle``. The
    ``.pickle`` extension is appended only when ``project_name`` does not
    already end with it, so a project name that merely contains
    ``.pickle`` is left untouched. The loading time is printed.

    :param project_name: project name, with or without the trailing
        ``.pickle`` extension
    :param path_model: folder containing the pickle file
    :return: the object stored in the pickle file
    :raises AssertionError: when the pickle file does not exist
    """
    t_start = time()

    suffix = '.pickle'

    if not project_name.endswith(suffix):
        project_name = project_name + suffix

    path_pickle = '{0}/{1}'.format(path_model, project_name)

    # Raised explicitly (rather than with an ``assert`` statement) so that
    # the check is still performed when Python runs with -O; the exception
    # type is kept for callers already catching AssertionError.
    if not isfile(path_pickle):
        raise AssertionError(
            'model file not found: {0}'.format(path_pickle))

    # NOTE(review): dill.load can execute arbitrary code while unpickling;
    # only load files coming from a trusted source.
    with open(path_pickle, 'rb') as f_pick:
        boosting = dill.load(f_pick)

    print('model loaded in %2.1f s' % (time() - t_start))

    return boosting
|
|
50 |
|
|
|
51 |
|
|
|
52 |
def metadata_usage_type(value):
    """
    Validate and normalize the ``metadata_usage`` option.

    Accepted values are None, False, 'labels', 'new-features',
    'test-labels', 'all' and True. True is converted to 'all'; every
    other accepted value is returned unchanged.

    :param value: option to validate
    :return: 'all' when ``value`` equals True, ``value`` otherwise
    :raises Exception: when ``value`` is not one of the accepted choices
    """
    # A tuple (equality based membership) is used instead of a set so that
    # unhashable inputs, such as a list, trigger the descriptive error
    # below rather than a bare TypeError.
    choices = (None,
               False,
               'labels',
               'new-features',
               'test-labels',
               'all',
               True)

    if value not in choices:
        raise Exception(
            "metadata_usage_type: {0} should be from the following choices:"
            " [None, False, 'labels', 'new-features', 'test-labels',"
            " 'all', True]".format(value))

    # Equality (not identity) is kept on purpose: the integer 1 passes the
    # membership test above because 1 == True, and is normalized as well.
    if value == True:  # noqa: E712
        return 'all'

    return value
|
|
69 |
|
|
|
70 |
|
|
|
71 |
def feature_selection_usage_type(value):
    """
    Check that ``value`` is a supported feature selection mode.

    :param value: one of 'individual', 'lasso' or None
    :return: ``value``, unchanged
    :raises Exception: when ``value`` is not one of the supported modes
    """
    allowed = {'individual', 'lasso', None}

    # Happy path first: a supported mode is handed back as is.
    if value in allowed:
        return value

    template = (
        "feature_selection_usage_type: {0} should be from the following choices:"
        " ['individual', 'lasso', None]")

    raise Exception(template.format(value))
|
|
82 |
|
|
|
83 |
|
|
|
84 |
def load_labels_file(path_labels, sep="\t"):
    """
    Parse a labels file into a dict.

    Each non-empty line must contain at least two columns separated by
    ``sep``: a patient ID and a label. The label is converted with
    ``int(float(label))``. An optional third column is read as the label
    probability (float); when it is absent, the label itself is used in
    its place. Lines that are empty once stripped are skipped.

    :param path_labels: path of the labels file
    :param sep: column separator
    :return: dict mapping each patient ID to a ``(label, proba)`` tuple
    :raises Exception: when a line has fewer than two columns, when the
        label cannot be converted to an int, or when the third column
        cannot be converted to a float
    """
    labels_dict = {}

    # The context manager guarantees the file is closed, including when
    # one of the errors below is raised.
    with open(path_labels) as f_labels:
        for line in f_labels:
            stripped = line.strip()

            # Blank lines (typically a trailing newline) carry no data.
            if not stripped:
                continue

            split = stripped.split(sep)

            if len(split) < 2:
                raise Exception(
                    '## Error: in load_labels_file {0} for line {1!r}:'
                    ' line cannot be split in at least 2 columns'.format(
                        path_labels, stripped))

            patient, label = split[0], split[1]

            try:
                label = int(float(label))
            except (ValueError, OverflowError):
                # OverflowError covers infinite values such as "inf".
                raise Exception(
                    '## Error: in load_labels_file {0} for line {1!r}:'
                    ' labels should be an int'.format(
                        path_labels, stripped))

            if len(split) > 2:
                try:
                    proba = float(split[2])
                except ValueError:
                    raise Exception(
                        '## Error: in load_labels_file {0} for line {1!r}:'
                        ' label proba in column 3 should be a float'.format(
                            path_labels, stripped))
            else:
                proba = label

            labels_dict[patient] = (label, proba)

    return labels_dict