Diff of /src/ovr/MLmodels.py [000000] .. [71ad2f]

Switch to unified view

a b/src/ovr/MLmodels.py
1
# multilabel classfication
2
from sklearn.multiclass import OneVsRestClassifier
3
4
# Linear Models
5
from sklearn.linear_model import LogisticRegression
6
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
7
from sklearn.svm import LinearSVC
8
from sklearn.naive_bayes import MultinomialNB
9
from lightgbm import LGBMClassifier
10
11
def train_classifier(X_train, y_train, X_valid=None, y_valid=None, C=1.0, model='lr'):
12
    """
13
      X_train, y_train — training data
14
      
15
      return: trained classifier
16
      
17
    """
18
    
19
    if model=='lr':
20
        model = LogisticRegression(C=C, penalty='l1', solver='liblinear')
21
        model = OneVsRestClassifier(model)
22
        model.fit(X_train, y_train)
23
    
24
    elif model=='svm':
25
        model = LinearSVC(C=C, penalty='l1', loss='squared_hinge')
26
        model = OneVsRestClassifier(model)
27
        model.fit(X_train, y_train)
28
    
29
    elif model=='nbayes':
30
        model = MultinomialNB(alpha=1.0)
31
        model = OneVsRestClassifier(model)
32
        model.fit(X_train, y_train)
33
        
34
    elif model=='lda':
35
        model = LinearDiscriminantAnalysis(solver='svd')
36
        model = OneVsRestClassifier(model)
37
        model.fit(X_train, y_train)
38
39
    return model