a b/simdeep/simdeep_utils.py
1
from simdeep.config import PATH_TO_SAVE_MODEL
2
3
from os.path import isfile
4
from os.path import isdir
5
6
from os import mkdir
7
8
# from sys import version_info
9
10
# if version_info > (3, 0, 0):
11
#     import pickle as cPickle
12
# else:
13
#     import cPickle
14
15
import dill
16
17
from time import time
18
19
def save_model(boosting, path_to_save_model=PATH_TO_SAVE_MODEL):
    """Serialise a model to ``<path_to_save_model>/<project name>.pickle``.

    The destination directory is created, including any missing parent
    directories, when it does not exist yet. ``boosting._convert_logs()``
    is called before the object is dumped with ``dill``, and the elapsed
    time is printed once the file is written.

    :param boosting: object exposing a ``_convert_logs()`` method and a
        ``_project_name`` attribute; the latter is used as the file stem.
    :param path_to_save_model: directory receiving the pickle file.
    """
    # Local import: unlike mkdir, makedirs also creates the intermediate
    # directories of a nested destination path.
    from os import makedirs

    if not isdir(path_to_save_model):
        makedirs(path_to_save_model)

    # NOTE(review): presumably turns the logging state into something
    # serialisable before the dump -- confirm against the model class.
    boosting._convert_logs()

    t_start = time()

    # Build the output path once so the file written and the path reported
    # below cannot drift apart.
    path_pickle = '{0}/{1}.pickle'.format(
        path_to_save_model, boosting._project_name)

    with open(path_pickle, 'wb') as f_pick:
        dill.dump(boosting, f_pick)

    print('model saved in %2.1f s at %s' % (time() - t_start, path_pickle))
35
36
37
def load_model(project_name, path_model=PATH_TO_SAVE_MODEL):
    """Load a model from ``<path_model>/<project_name>.pickle``.

    ``project_name`` may be given with or without the ``.pickle``
    extension. The elapsed loading time is printed.

    :param project_name: project name, or pickle file name, of the model.
    :param path_model: directory containing the pickle file.
    :return: the object deserialised with ``dill``.
    :raises AssertionError: when the pickle file does not exist.
    """
    t_start = time()

    suffix = '.pickle'

    # Only a trailing extension is significant. Removing every '.pickle'
    # occurrence would rewrite a project name that merely contains that
    # text and point at a file that save_model never wrote.
    if not project_name.endswith(suffix):
        project_name = project_name + suffix

    path_pickle = '{0}/{1}'.format(path_model, project_name)

    # Kept as an assert so the exception type seen by callers is unchanged.
    # Under `python -O` the check is skipped and open() below raises instead.
    assert isfile(path_pickle), \
        'model file not found: {0}'.format(path_pickle)

    # SECURITY: dill.load can execute arbitrary code embedded in the file;
    # only load pickles coming from a trusted source.
    with open(path_pickle, 'rb') as f_pick:
        boosting = dill.load(f_pick)

    print('model loaded in %2.1f s' % (time() - t_start))

    return boosting
50
51
52
def metadata_usage_type(value):
    """Validate a metadata-usage option and return its normalised form.

    Accepted values are ``None``, ``False``, ``'labels'``,
    ``'new-features'``, ``'test-labels'``, ``'all'`` and ``True``.
    ``True`` is normalised to ``'all'``; every other accepted value is
    returned unchanged.

    :param value: option to validate.
    :return: ``'all'`` when ``value`` equals ``True``, else ``value``.
    :raises Exception: when ``value`` is not one of the accepted choices.
    """
    if value not in {None,
                     False,
                     'labels',
                     'new-features',
                     'test-labels',
                     'all',
                     True}:
        # The message lists every accepted choice, 'test-labels' included.
        raise Exception(
            "metadata_usage_type: {0} should be from the following choices:"
            " [None, False, 'labels', 'new-features', 'test-labels',"
            " 'all', True]".format(value))

    # Equality rather than identity: anything that passed the membership
    # test by comparing equal to True (e.g. the integer 1) is normalised
    # the same way as True itself.
    if value == True:
        return 'all'

    return value
69
70
71
def feature_selection_usage_type(value):
    """Validate a feature-selection option and hand it back unchanged.

    Accepted values are ``'individual'``, ``'lasso'`` and ``None``.

    :param value: option to validate.
    :return: ``value`` itself when it is an accepted choice.
    :raises Exception: when ``value`` is not one of the accepted choices.
    """
    accepted = {'individual', 'lasso', None}

    # Guard clause: valid options leave immediately.
    if value in accepted:
        return value

    message = (
        "feature_selection_usage_type: {0} should be from the following choices:"
        " ['individual', 'lasso', None]")

    raise Exception(message.format(value))
82
83
84
def load_labels_file(path_labels, sep="\t"):
    """Parse a delimited labels file into a ``{patient: (label, proba)}`` dict.

    Every non-blank line must hold at least two ``sep``-separated columns:
    a patient identifier and a label, the latter parsed with
    ``int(float(...))``. An optional third column is parsed as a float
    probability; when it is absent the label itself is stored as the
    probability. Blank lines are skipped. When a patient identifier
    appears several times, the last line wins.

    :param path_labels: path of the labels file to read.
    :param sep: column separator, a tab by default.
    :return: dict mapping each patient identifier to ``(label, proba)``.
    :raises Exception: when a line has fewer than two columns, when the
        label is not numeric, or when the third column is not a float.
    """
    labels_dict = {}

    # The context manager guarantees the handle is closed, including when a
    # malformed line raises part-way through the file.
    with open(path_labels) as f_labels:
        for line in f_labels:
            stripped = line.strip()

            # A blank line (typically a trailing newline) carries no record.
            if not stripped:
                continue

            split = stripped.split(sep)

            if len(split) < 2:
                raise Exception(
                    '## Error for file in load_labels_file: {0} for line {1}'
                    ' line cannot be splitted in more than 2'.format(
                        path_labels, line))

            patient, label = split[0], split[1]

            try:
                # Going through float accepts labels written as "1.0".
                label = int(float(label))
            except (ValueError, OverflowError):
                raise Exception(
                    '## Error: in load_labels_file {0} for line {1}'
                    'labels should be an int'.format(
                        path_labels, line))

            if len(split) > 2:
                try:
                    proba = float(split[2])
                except ValueError:
                    raise Exception(
                        '## Error: in load_labels_file {0} for line {1}'
                        'label proba in column 3 should be a float'.format(
                            path_labels, line))
            else:
                # No probability column: fall back to the label. This branch
                # belongs to the column-count test, not to the try above,
                # otherwise a parsed probability would be overwritten and
                # two-column lines would leave proba unset.
                proba = label

            labels_dict[patient] = (label, proba)

    return labels_dict