[70b42d]: / src / modeling.py

Download this file

22 lines (19 with data), 819 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.svm import SVC
def get_classification_model(model_type='logistic_regression', **model_kwargs):
    """Build a new scikit-learn classifier selected by name.

    Args:
        model_type: Which classifier to create. One of
            ``'logistic_regression'`` (``LogisticRegression``),
            ``'random_forest'`` (``RandomForestClassifier``) or
            ``'svm'`` (``SVC``).
        **model_kwargs: Optional keyword arguments forwarded unchanged to
            the selected estimator's constructor. With none given, the
            estimator is created with its default settings.

    Returns:
        A freshly constructed instance of the selected classifier.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    supported = ('logistic_regression', 'random_forest', 'svm')

    # Tuple membership compares with ``==``, so any value -- including None
    # or an unhashable object -- is rejected with ValueError rather than
    # failing later inside the dict lookup.
    if model_type not in supported:
        raise ValueError(
            "Unsupported model type for classification: {!r} "
            "(expected one of: {})".format(model_type, ', '.join(supported))
        )

    estimators = {
        'logistic_regression': LogisticRegression,
        'random_forest': RandomForestClassifier,
        'svm': SVC,
    }
    return estimators[model_type](**model_kwargs)
def get_regression_model(model_type='linear_regression', **model_kwargs):
    """Build a new scikit-learn regressor selected by name.

    Args:
        model_type: Which regressor to create. One of
            ``'linear_regression'`` (``LinearRegression``) or
            ``'random_forest'`` (``RandomForestRegressor``).
        **model_kwargs: Optional keyword arguments forwarded unchanged to
            the selected estimator's constructor. With none given, the
            estimator is created with its default settings.

    Returns:
        A freshly constructed instance of the selected regressor.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    supported = ('linear_regression', 'random_forest')

    # Tuple membership compares with ``==``, so any value -- including None
    # or an unhashable object -- is rejected with ValueError rather than
    # failing later inside the dict lookup.
    if model_type not in supported:
        raise ValueError(
            "Unsupported model type for regression: {!r} "
            "(expected one of: {})".format(model_type, ', '.join(supported))
        )

    estimators = {
        'linear_regression': LinearRegression,
        'random_forest': RandomForestRegressor,
    }
    return estimators[model_type](**model_kwargs)