a b/Stats/_LogisticRegressionCV.py
1
#!/usr/bin/env python
2
# -*- coding: UTF-8 -*-
3
#
4
# Copyright 2017 University of Westminster. All Rights Reserved.
5
#
6
# Licensed under the Apache License, Version 2.0 (the "License");
7
# you may not use this file except in compliance with the License.
8
# You may obtain a copy of the License at
9
#
10
#     http://www.apache.org/licenses/LICENSE-2.0
11
#
12
# Unless required by applicable law or agreed to in writing, software
13
# distributed under the License is distributed on an "AS IS" BASIS,
14
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
# See the License for the specific language governing permissions and
16
# limitations under the License.
17
# ==============================================================================
18
"""Interface for the scikit-learn 'LogisticRegressionCV' training model (logistic regression with cross-validation).
19
"""
20
21
from typing import Any, Dict, List, Optional, TypeVar

from sklearn import linear_model

from Stats.Stats import Stats
24
25
# Opaque placeholders used only in type annotations (presumably so this module
# does not have to import pandas / the sklearn estimator class for typing).
# PEP 484 requires the TypeVar name string to equal the variable it is bound to;
# the previous 'DataFrame' / 'LogisticRegressionCV' strings violated that rule.
PandasDataFrame = TypeVar('PandasDataFrame')
SklearnLogisticRegressionCV = TypeVar('SklearnLogisticRegressionCV')
27
28
# Module metadata.
# NOTE(review): __license__ says "GPL" while the file header declares the
# Apache License, Version 2.0, and __copyright__ (2016, personal GitHub) differs
# from the header (2017, University of Westminster). Confirm which is correct.

# People.
__author__ = "Mohsen Mesgarpour"
__maintainer__ = "Mohsen Mesgarpour"
__credits__ = ["Mohsen Mesgarpour"]
__email__ = "mohsen.mesgarpour@gmail.com"

# Legal and release information.
__copyright__ = "Copyright 2016, https://github.com/mesgarpour"
__license__ = "GPL"
__version__ = "1.1"
__status__ = "Release"
36
37
38
class _LogisticRegressionCV(Stats):
    """Wrapper around scikit-learn's ``linear_model.LogisticRegressionCV``.

    Exposes ``train``, ``train_summaries`` and ``plot`` on top of the ``Stats``
    base class, which supplies ``self._logger``.
    """

    def __init__(self):
        """Initialise the objects and constants.
        """
        # Zero-argument super() binds to this class. The former
        # super(self.__class__, self) form recursed forever when subclassed.
        super().__init__()
        self._logger.debug(__name__)
        self._logger.debug("Run Logistic Regression with Cross-Validation.")

    def train(self,
              features_indep_df: PandasDataFrame,
              feature_target: List,
              model_labals: Optional[List] = None,
              **kwargs: Any) -> SklearnLogisticRegressionCV:
        """Perform the training, using the Logistic Regression with Cross-Validation.
        :param features_indep_df: the independent features, which are inputted into the model.
        :param feature_target: the target feature, which is being estimated.
        :param model_labals: the target labels (conceptually [0, 1]). Accepted but currently
            unused by this implementation; scikit-learn infers the classes from `feature_target`.
        :param kwargs: keyword arguments forwarded unchanged to
            `sklearn.linear_model.LogisticRegressionCV` (e.g. Cs, fit_intercept, cv, dual, penalty,
            scoring, solver, tol, max_iter, class_weight, n_jobs, verbose, refit, intercept_scaling,
            random_state). This wrapper sets no defaults of its own, so any omitted argument takes
            scikit-learn's default for the installed version.
        :return: the trained (fitted) model.
        """
        self._logger.debug("Train " + __name__)
        model_train = linear_model.LogisticRegressionCV(**kwargs)
        model_train.fit(features_indep_df, feature_target)
        return model_train

    def train_summaries(self,
                        model_train: SklearnLogisticRegressionCV) -> Dict:
        """Produce the training summary.
        :param model_train: the instance of the trained model.
        :return: the training summary, with keys "coef_", "intercept_" and "n_iter_".
        """
        self._logger.debug("Summarise " + __name__)
        summaries = {
            # Coefficient of the features in the decision function.
            "coef_": model_train.coef_,
            # Intercept (a.k.a. bias) added to the decision function.
            "intercept_": model_train.intercept_,
            # Actual number of iterations for every class, CV fold and C value. Per the
            # scikit-learn docs the shape is (n_classes, n_folds, n_cs), or
            # (1, n_folds, n_cs) in the binary or multinomial case.
            "n_iter_": model_train.n_iter_,
        }
        return summaries

    def plot(self,
             model_train: SklearnLogisticRegressionCV,
             feature_names: List,
             class_names: Optional[List] = None):
        """Plot the trained model. Not implemented yet for logistic regression.
        :param model_train: the instance of the trained model (currently unused).
        :param feature_names: the names of input features (currently unused).
        :param class_names: the predicted class labels, conceptually ["True", "False"]
            (currently unused).
        :return: None, because plotting is not implemented for this model.
        """
        self._logger.debug("Plot " + __name__)
        # TODO: implement a plot (e.g. a per-feature coefficient chart).
        return None