|
a |
|
b/Stats/_LogisticRegressionCV.py |
|
|
1 |
#!/usr/bin/env python |
|
|
2 |
# -*- coding: UTF-8 -*- |
|
|
3 |
# |
|
|
4 |
# Copyright 2017 University of Westminster. All Rights Reserved. |
|
|
5 |
# |
|
|
6 |
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
7 |
# you may not use this file except in compliance with the License. |
|
|
8 |
# You may obtain a copy of the License at |
|
|
9 |
# |
|
|
10 |
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
11 |
# |
|
|
12 |
# Unless required by applicable law or agreed to in writing, software |
|
|
13 |
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
14 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
15 |
# See the License for the specific language governing permissions and |
|
|
16 |
# limitations under the License. |
|
|
17 |
# ============================================================================== |
|
|
18 |
""" It is an interface for the 'LogisticRegressionCV' training model (Logistic Regression with Cross-Validation). |
|
|
19 |
""" |
|
|
20 |
|
|
|
21 |
from typing import Any, Dict, List, Optional, TypeVar

from sklearn import linear_model

from Stats.Stats import Stats
|
|
24 |
|
|
|
25 |
# Annotation-only placeholders for third-party types, so this module can name
# them in signatures without importing them for that purpose alone.
# The string passed to TypeVar must equal the variable it is assigned to;
# type checkers reject a mismatch.
PandasDataFrame = TypeVar('PandasDataFrame')
SklearnLogisticRegressionCV = TypeVar('SklearnLogisticRegressionCV')

# Module metadata.
# NOTE(review): the file header licenses this module under Apache 2.0 while
# __license__ below says "GPL" - confirm which one is intended.
__author__ = "Mohsen Mesgarpour"
__copyright__ = "Copyright 2016, https://github.com/mesgarpour"
__credits__ = ["Mohsen Mesgarpour"]
__license__ = "GPL"
__version__ = "1.1"
__maintainer__ = "Mohsen Mesgarpour"
__email__ = "mohsen.mesgarpour@gmail.com"
__status__ = "Release"
|
|
36 |
|
|
|
37 |
|
|
|
38 |
class _LogisticRegressionCV(Stats):
    """Interface for scikit-learn's ``LogisticRegressionCV`` estimator.

    Provides training, a summary of the fitted parameters and a (not yet
    implemented) plotting hook.
    """

    def __init__(self):
        """Initialise the objects and constants.
        """
        # Zero-argument super(): ``super(self.__class__, self)`` resolves to
        # the runtime class, so it recurses forever once this class is
        # subclassed.
        super().__init__()
        # ``self._logger`` is expected to be set up by the ``Stats`` base
        # initialiser - it is not defined in this class.
        self._logger.debug(__name__)
        self._logger.debug("Run Logistic Regression with Cross-Validation.")

    def train(self,
              features_indep_df: PandasDataFrame,
              feature_target: List,
              model_labals: Optional[List] = None,
              **kwargs: Any) -> SklearnLogisticRegressionCV:
        """Perform the training, using the Logistic Regression with Cross-Validation.
        :param features_indep_df: the independent features, which are inputted into the model.
        :param feature_target: the target feature, which is being estimated.
        :param model_labals: the target labels. Not used by this method; the default is None
            (formerly the list [0, 1]) so that no mutable object is shared between calls.
        :param kwargs: keyword arguments forwarded unchanged to ``sklearn.linear_model.LogisticRegressionCV``,
            e.g. Cs, fit_intercept, cv, dual, penalty, scoring, solver, tol, max_iter, class_weight, n_jobs,
            verbose, refit, intercept_scaling, random_state. No defaults are applied here, so any omitted
            option takes the default of the installed scikit-learn version.
        :return: the trained model.
        """
        self._logger.debug("Train " + __name__)
        model_train = linear_model.LogisticRegressionCV(**kwargs)
        model_train.fit(features_indep_df, feature_target)
        return model_train

    def train_summaries(self,
                        model_train: SklearnLogisticRegressionCV) -> Dict:
        """Produce the training summary.
        :param model_train: the instance of the trained model.
        :return: the training summary, a dictionary with the keys "coef_", "intercept_" and "n_iter_",
            each holding the attribute of the same name read from the trained model.
        """
        self._logger.debug("Summarise " + __name__)
        return {
            # Coefficient of the features in the decision function.
            "coef_": model_train.coef_,
            # Intercept (a.k.a. bias) added to the decision function.
            "intercept_": model_train.intercept_,
            # Actual number of iterations for all classes.
            "n_iter_": model_train.n_iter_,
        }

    def plot(self,
             model_train: SklearnLogisticRegressionCV,
             feature_names: List,
             class_names: Optional[List] = None) -> None:
        """Plot the trained model (not implemented yet).

        Currently this only writes a debug log entry; no plot is produced.
        :param model_train: the instance of the trained model (currently unused).
        :param feature_names: the names of input features (currently unused).
        :param class_names: the predicted class labels (currently unused). The default is None
            (formerly the list ["True", "False"]) so that no mutable object is shared between calls.
        :return: None.
        """
        self._logger.debug("Plot " + __name__)
        # todo: plot