[71ad2f]: / src / ovr / MLmodels.py

Download this file

39 lines (30 with data), 1.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# multilabel classification (one-vs-rest wrapper)
from sklearn.multiclass import OneVsRestClassifier
# Base estimators: linear models, naive Bayes and gradient boosting
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.svm import LinearSVC
from sklearn.naive_bayes import MultinomialNB
from lightgbm import LGBMClassifier
def train_classifier(X_train, y_train, X_valid=None, y_valid=None, C=1.0, model='lr'):
"""
X_train, y_train — training data
return: trained classifier
"""
if model=='lr':
model = LogisticRegression(C=C, penalty='l1', solver='liblinear')
model = OneVsRestClassifier(model)
model.fit(X_train, y_train)
elif model=='svm':
model = LinearSVC(C=C, penalty='l1', loss='squared_hinge')
model = OneVsRestClassifier(model)
model.fit(X_train, y_train)
elif model=='nbayes':
model = MultinomialNB(alpha=1.0)
model = OneVsRestClassifier(model)
model.fit(X_train, y_train)
elif model=='lda':
model = LinearDiscriminantAnalysis(solver='svd')
model = OneVsRestClassifier(model)
model.fit(X_train, y_train)
return model