Switch to unified view

a b/Stats/_GradientBoostingClassifier.py
1
#!/usr/bin/env python
2
# -*- coding: UTF-8 -*-
3
#
4
# Copyright 2017 University of Westminster. All Rights Reserved.
5
#
6
# Licensed under the Apache License, Version 2.0 (the "License");
7
# you may not use this file except in compliance with the License.
8
# You may obtain a copy of the License at
9
#
10
#     http://www.apache.org/licenses/LICENSE-2.0
11
#
12
# Unless required by applicable law or agreed to in writing, software
13
# distributed under the License is distributed on an "AS IS" BASIS,
14
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
# See the License for the specific language governing permissions and
16
# limitations under the License.
17
# ==============================================================================
18
"""Interface for the 'GradientBoostingClassifier' training model (Gradient Boosting Classifier).
"""
20
21
from typing import Dict, List, Any, TypeVar
22
from Stats.Stats import Stats
23
from sklearn import ensemble
24
25
# Placeholder type variables used purely in annotations in this module.
# A TypeVar's name argument must equal the variable it is assigned to,
# otherwise static type checkers reject it and its repr is misleading.
PandasDataFrame = TypeVar('PandasDataFrame')
SklearnGradientBoostingClassifier = TypeVar('SklearnGradientBoostingClassifier')

# Module metadata.
__author__ = "Mohsen Mesgarpour"
__copyright__ = "Copyright 2016, https://github.com/mesgarpour"
__credits__ = ["Mohsen Mesgarpour"]
# NOTE(review): the file header states Apache License 2.0 while this says
# "GPL" -- confirm which licence is intended before changing either.
__license__ = "GPL"
__version__ = "1.1"
__maintainer__ = "Mohsen Mesgarpour"
__email__ = "mohsen.mesgarpour@gmail.com"
__status__ = "Release"
36
37
38
class _GradientBoostingClassifier(Stats):
    """Wrapper around ``sklearn.ensemble.GradientBoostingClassifier``.

    Offers three operations: fitting a model (``train``), collecting the
    fitted model's attributes into a dictionary (``train_summaries``) and a
    plotting hook that is not implemented yet (``plot``).
    """

    def __init__(self):
        """Initialise the objects and constants.
        """
        # Zero-argument super() is bound to this class at definition time.
        # The previous ``super(self.__class__, self)`` form resolves to the
        # runtime class and therefore recurses forever once this class is
        # subclassed.
        super().__init__()
        self._logger.debug(__name__)
        self._logger.debug("Run Gradient Boosting Classifier.")

    def train(self,
              features_indep_df: PandasDataFrame,
              feature_target: List,
              model_labals: List = [0, 1],
              **kwargs: Any) -> SklearnGradientBoostingClassifier:
        """Perform the training, using the Gradient Boosting Classifier.

        :param features_indep_df: the independent features, which are inputted into the model;
            its ``.values`` array is what is passed to ``fit``.
        :param feature_target: the target feature, which is being estimated.
        :param model_labals: the target labels (default [0, 1]). Accepted for interface
            compatibility; this method does not currently read it.
        :param kwargs: forwarded unchanged to ``ensemble.GradientBoostingClassifier``.
            Options recorded by the original author (values may not match the defaults of the
            installed scikit-learn version): loss='deviance', learning_rate=0.1,
            n_estimators=100, subsample=1.0, min_samples_split=30, min_samples_leaf=30,
            min_weight_fraction_leaf=0.0, max_depth=3, init=None, random_state=None,
            max_features=None, verbose=0, max_leaf_nodes=None, warm_start=False, presort='auto'
        :return: the trained model.
        """
        self._logger.debug("Train " + __name__)
        model_train = ensemble.GradientBoostingClassifier(**kwargs)
        model_train.fit(features_indep_df.values, feature_target)
        return model_train

    def train_summaries(self,
                        model_train: SklearnGradientBoostingClassifier) -> Dict:
        """Produce the training summary.

        :param model_train: the instance of the trained model.
        :return: the training summary: a dictionary holding the model's
            ``feature_importances_``, ``train_score_``, ``loss_``, ``init`` and ``estimators_``
            attributes under keys of the same name.
        """
        self._logger.debug("Summarise " + __name__)
        summaries = dict()
        summaries['feature_importances_'] = model_train.feature_importances_
        summaries['train_score_'] = model_train.train_score_
        # NOTE(review): ``loss_`` appears to be deprecated/removed in newer
        # scikit-learn releases -- confirm against the pinned version.
        summaries['loss_'] = model_train.loss_
        summaries['init'] = model_train.init
        summaries['estimators_'] = model_train.estimators_
        return summaries

    def plot(self,
             model_train: SklearnGradientBoostingClassifier,
             feature_names: List,
             class_names: List = ["True", "False"]) -> None:
        """Plot the tree diagram (not implemented yet).

        At present this method only writes a debug log entry; none of the
        arguments are used and nothing is drawn.

        :param model_train: the instance of the trained model.
        :param feature_names: the names of input features.
        :param class_names: the predicted class labels.
        :return: None (no graph is produced until plotting is implemented).
        """
        self._logger.debug("Plot " + __name__)
        # todo: plot