# multilabel classification
from sklearn.multiclass import OneVsRestClassifier
# Linear Models
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.svm import LinearSVC
from sklearn.naive_bayes import MultinomialNB
from lightgbm import LGBMClassifier
def train_classifier(X_train, y_train, X_valid=None, y_valid=None, C=1.0, model='lr'):
    """Fit a one-vs-rest classifier of the requested kind on the training data.

    Args:
        X_train: Training features, passed unchanged to ``fit``.
        y_train: Training targets, passed unchanged to ``fit``.
        X_valid: Accepted for interface compatibility; not used here.
        y_valid: Accepted for interface compatibility; not used here.
        C: Value forwarded as the ``C`` argument of the base estimator.
            Only the ``'lr'`` and ``'svm'`` models use it.
        model: Name of the base estimator to wrap:
            ``'lr'`` (L1-penalized ``LogisticRegression``),
            ``'svm'`` (L1-penalized ``LinearSVC``),
            ``'nbayes'`` (``MultinomialNB``) or
            ``'lda'`` (``LinearDiscriminantAnalysis``).

    Returns:
        The fitted ``OneVsRestClassifier`` wrapping the chosen estimator.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    if model == 'lr':
        base_estimator = LogisticRegression(C=C, penalty='l1', solver='liblinear')
    elif model == 'svm':
        # The L1 penalty with squared hinge loss is only solvable in the
        # primal form, so dual=False must be set explicitly; otherwise
        # scikit-learn versions that default to dual=True reject the setup.
        base_estimator = LinearSVC(
            C=C, penalty='l1', loss='squared_hinge', dual=False
        )
    elif model == 'nbayes':
        base_estimator = MultinomialNB(alpha=1.0)
    elif model == 'lda':
        base_estimator = LinearDiscriminantAnalysis(solver='svd')
    else:
        # Fail loudly instead of handing the caller back the untouched
        # ``model`` argument as if it were a trained classifier.
        raise ValueError(
            f"Unknown model {model!r}; expected one of "
            "'lr', 'svm', 'nbayes', 'lda'"
        )

    classifier = OneVsRestClassifier(base_estimator)
    classifier.fit(X_train, y_train)
    return classifier