|
a |
|
b/src/ovr/MLmodels.py |
|
|
1 |
# multilabel classfication |
|
|
2 |
from sklearn.multiclass import OneVsRestClassifier |
|
|
3 |
|
|
|
4 |
# Linear Models |
|
|
5 |
from sklearn.linear_model import LogisticRegression |
|
|
6 |
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis |
|
|
7 |
from sklearn.svm import LinearSVC |
|
|
8 |
from sklearn.naive_bayes import MultinomialNB |
|
|
9 |
from lightgbm import LGBMClassifier |
|
|
10 |
|
|
|
11 |
def train_classifier(X_train, y_train, X_valid=None, y_valid=None, C=1.0, model='lr'): |
|
|
12 |
""" |
|
|
13 |
X_train, y_train — training data |
|
|
14 |
|
|
|
15 |
return: trained classifier |
|
|
16 |
|
|
|
17 |
""" |
|
|
18 |
|
|
|
19 |
if model=='lr': |
|
|
20 |
model = LogisticRegression(C=C, penalty='l1', solver='liblinear') |
|
|
21 |
model = OneVsRestClassifier(model) |
|
|
22 |
model.fit(X_train, y_train) |
|
|
23 |
|
|
|
24 |
elif model=='svm': |
|
|
25 |
model = LinearSVC(C=C, penalty='l1', loss='squared_hinge') |
|
|
26 |
model = OneVsRestClassifier(model) |
|
|
27 |
model.fit(X_train, y_train) |
|
|
28 |
|
|
|
29 |
elif model=='nbayes': |
|
|
30 |
model = MultinomialNB(alpha=1.0) |
|
|
31 |
model = OneVsRestClassifier(model) |
|
|
32 |
model.fit(X_train, y_train) |
|
|
33 |
|
|
|
34 |
elif model=='lda': |
|
|
35 |
model = LinearDiscriminantAnalysis(solver='svd') |
|
|
36 |
model = OneVsRestClassifier(model) |
|
|
37 |
model.fit(X_train, y_train) |
|
|
38 |
|
|
|
39 |
return model |