# Multilabel classification
from sklearn.multiclass import OneVsRestClassifier

# Linear models
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.svm import LinearSVC
from sklearn.naive_bayes import MultinomialNB
from lightgbm import LGBMClassifier  # not used by train_classifier; kept for importers of this module


def _make_base_estimator(name, C):
    """Return the unfitted per-label estimator registered under ``name``.

    Raises:
        ValueError: if ``name`` is not one of 'lr', 'svm', 'nbayes', 'lda'.
    """
    if name == 'lr':
        return LogisticRegression(C=C, penalty='l1', solver='liblinear')
    if name == 'svm':
        # penalty='l1' with loss='squared_hinge' is only solvable in the
        # primal form, so dual=False is stated explicitly instead of relying
        # on the library default (which rejects this combination on
        # scikit-learn versions where dual defaults to True).
        return LinearSVC(C=C, penalty='l1', loss='squared_hinge', dual=False)
    if name == 'nbayes':
        return MultinomialNB(alpha=1.0)
    if name == 'lda':
        return LinearDiscriminantAnalysis(solver='svd')
    raise ValueError(
        "Unknown model {!r}; expected one of 'lr', 'svm', 'nbayes', 'lda'".format(name)
    )


def train_classifier(X_train, y_train, X_valid=None, y_valid=None, C=1.0, model='lr'):
    """Fit a one-vs-rest classifier on the training data and return it.

    Args:
        X_train: training features, passed unchanged to ``fit``.
        y_train: training targets, passed unchanged to ``fit``.
        X_valid: currently unused; accepted so existing call sites keep working.
        y_valid: currently unused; accepted so existing call sites keep working.
        C: regularization parameter forwarded to the 'lr' and 'svm' base
            estimators; ignored by 'nbayes' and 'lda'.
        model: name of the base estimator:
            'lr'     -> LogisticRegression(penalty='l1', solver='liblinear')
            'svm'    -> LinearSVC(penalty='l1', loss='squared_hinge')
            'nbayes' -> MultinomialNB(alpha=1.0)
            'lda'    -> LinearDiscriminantAnalysis(solver='svd')

    Returns:
        The fitted ``OneVsRestClassifier`` wrapping the chosen base estimator.

    Raises:
        ValueError: if ``model`` is not one of the names listed above.
            (Previously the unrecognized argument itself was returned.)
    """
    classifier = OneVsRestClassifier(_make_base_estimator(model, C))
    classifier.fit(X_train, y_train)
    return classifier