Diff of /src/evaluation.py [000000] .. [70b42d]

--- a
+++ b/src/evaluation.py
@@ -0,0 +1,190 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, mean_squared_error, r2_score, roc_curve, auc
+from sklearn.model_selection import validation_curve
+from sklearn.preprocessing import label_binarize
+
+def evaluate_classification_model(pipeline, X_test, y_test):
+    y_pred = pipeline.predict(X_test)
+    accuracy = accuracy_score(y_test, y_pred)
+    precision = precision_score(y_test, y_pred, average='weighted')
+    recall = recall_score(y_test, y_pred, average='weighted')
+    f1 = f1_score(y_test, y_pred, average='weighted')
+    return accuracy, precision, recall, f1
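+
+# Usage sketch (assumption: `clf_pipeline` is an already-fitted sklearn estimator or
+# Pipeline and (X_test, y_test) is a held-out split; the names are illustrative):
+#     acc, prec, rec, f1 = evaluate_classification_model(clf_pipeline, X_test, y_test)
+#     print(f"accuracy={acc:.3f}  f1={f1:.3f}")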
+
+def evaluate_regression_model(pipeline, X_test, y_test):
+    y_pred = pipeline.predict(X_test)
+    mse = mean_squared_error(y_test, y_pred)
+    r2 = r2_score(y_test, y_pred)
+    return mse, r2
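+
+# Usage sketch (assumption: `reg_pipeline` is a fitted regressor; the names are illustrative):
+#     mse, r2 = evaluate_regression_model(reg_pipeline, X_test, y_test)
+#     print(f"MSE={mse:.3f}  R^2={r2:.3f}")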
+
+# def plot_confusion_matrix(pipeline, X_test, y_test, classes):
+#     y_pred = pipeline.predict(X_test)
+#     cm = confusion_matrix(y_test, y_pred, labels=classes)
+#     plt.figure(figsize=(10, 7))
+#     sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
+#     plt.xlabel('Predicted')
+#     plt.ylabel('Actual')
+#     plt.show()
+
+
+def plot_confusion_matrix(y_true, y_pred, classes, cmap='Blues', figsize=(10, 7)):
+    """
+    Plots a confusion matrix using matplotlib.
+    
+    Parameters:
+    y_true : array-like of shape (n_samples,)
+        True labels.
+    y_pred : array-like of shape (n_samples,)
+        Predicted labels.
+    classes : array-like of shape (n_classes,)
+        List of class labels.
+    cmap : str, default='Blues'
+        Colormap for the heatmap.
+    figsize : tuple, default=(10, 7)
+        Figure size.
+    """
+    # Compute the confusion matrix (ordered according to `classes`)
+    cm = confusion_matrix(y_true, y_pred, labels=classes)
+    
+    # Create the figure and axis
+    fig, ax = plt.subplots(figsize=figsize)
+    
+    # Draw the heatmap
+    cax = ax.matshow(cm, cmap=cmap)
+    
+    # Add a colorbar
+    fig.colorbar(cax)
+    
+    # Annotate the heatmap with the confusion matrix values
+    for (i, j), val in np.ndenumerate(cm):
+        ax.text(j, i, f'{val}', ha='center', va='center', color='black')
+    
+    # Set the axis tick labels
+    ax.set_xticks(np.arange(len(classes)))
+    ax.set_yticks(np.arange(len(classes)))
+    ax.set_xticklabels(classes, rotation=45)
+    ax.set_yticklabels(classes)
+    
+    # Set the axis labels and the title
+    plt.xlabel('Predicted Labels')
+    plt.ylabel('True Labels')
+    plt.title('Confusion Matrix')
+    
+    # Display the heatmap
+    plt.show()
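+
+# Usage sketch (assumption: `y_test` and `y_pred` come from a fitted classifier; the
+# class list is illustrative):
+#     plot_confusion_matrix(y_test, y_pred, classes=['negative', 'positive'])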
+
+def evaluate_model_multiclass(y_true, y_pred, class_labels):
+    """
+    Evaluate a multi-class classification model using confusion matrix, classification metrics, 
+    and ROC AUC.
+
+    Args:
+        y_true (array): True class labels.
+        y_pred (array): Predicted class labels.
+        class_labels (array): List of class labels.
+    """
+    # Plot confusion matrix
+    plot_confusion_matrix(y_true, y_pred, classes=class_labels)
+
+    # Calculate classification metrics
+    # precision = precision_score(y_true, y_pred, average='weighted')
+    # recall = recall_score(y_true, y_pred, average='weighted')
+    f1 = f1_score(y_true, y_pred, average='weighted')
+
+    # Display classification metrics
+    print("*** Classification Metrics ***")
+    # print("Precision =", precision)
+    # print("Recall =", recall)
+    print("F1 Score =", f1)
+    print("******************************")
+
+    # Binarize the labels (note: the ROC curves below are built from binarized hard
+    # predictions rather than probability scores, so each curve has a single operating point)
+    y_onehot_test = label_binarize(y_true, classes=class_labels)
+    y_score = label_binarize(y_pred, classes=class_labels)
+
+    # ROC AUC for multi-class classification
+    fpr = {}
+    tpr = {}
+    roc_auc = {}
+
+    for i in range(len(class_labels)):
+        fpr[i], tpr[i], _ = roc_curve(y_onehot_test[:, i], y_score[:, i])
+        roc_auc[i] = auc(fpr[i], tpr[i])
+
+    # Plot ROC curve for each class
+    plt.figure(figsize=(8, 6))
+    for i in range(len(class_labels)):
+        plt.plot(fpr[i], tpr[i], label='Class %d (AUC=%0.3f)' % (i, roc_auc[i]))
+
+    plt.plot([0, 1], [0, 1], 'k--')
+    plt.xlim([0.0, 1.0])
+    plt.ylim([0.0, 1.05])
+    plt.title('Multi-class ROC Curve')
+    plt.xlabel('False Positive Rate')
+    plt.ylabel('True Positive Rate')
+    plt.legend(loc="lower right")
+    plt.show()
+
+    # Print AUC scores for each class
+    print("AUC scores for each class:", roc_auc)
+
+    # Calculate micro-average ROC AUC
+    fpr["micro"], tpr["micro"], _ = roc_curve(y_onehot_test.ravel(), y_score.ravel())
+    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
+
+    # Plot micro-average ROC curve
+    plt.figure(figsize=(8, 6))
+    plt.plot(fpr["micro"], tpr["micro"], label='micro-average ROC curve (AUC = %0.3f)' % roc_auc["micro"], color='darkorange')
+    plt.plot([0, 1], [0, 1], 'k--')
+    plt.xlim([0.0, 1.0])
+    plt.ylim([0.0, 1.05])
+    plt.title('Micro-averaged ROC Curve')
+    plt.xlabel('False Positive Rate')
+    plt.ylabel('True Positive Rate')
+    plt.legend(loc="lower right")
+    plt.show()
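+
+# Usage sketch (assumption: `y_test` and `y_pred` are label arrays from a fitted
+# multi-class classifier; the label list is illustrative):
+#     evaluate_model_multiclass(y_test, y_pred, class_labels=[0, 1, 2])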
+
+def plot_validation_curve(model, X_train, y_train, param_name, param_range, cv=5, scoring='f1_weighted', title=None):
+    """
+    Plot a validation curve for an estimator over a range of hyperparameter values
+    (written with an L2-regularized logistic regression varying C in mind).
+
+    Args:
+        model: The estimator to evaluate.
+        X_train: Training data.
+        y_train: Training labels.
+        param_name: Name of the parameter to vary.
+        param_range: Range of parameter values to evaluate.
+        cv: Number of cross-validation folds.
+        scoring: Scoring method to use.
+        title: Title of the plot.
+    """
+    train_scores, val_scores = validation_curve(
+        model, 
+        X_train, 
+        y_train, 
+        param_name=param_name, 
+        param_range=param_range, 
+        cv=cv, 
+        scoring=scoring
+    )
+
+    plt.figure(figsize=(12, 4))
+    plt.plot(param_range, train_scores.mean(axis=1), label='train')
+    plt.plot(param_range, val_scores.mean(axis=1), label='validation')
+    plt.legend()
+
+    print("train scores:", train_scores.mean(axis=1))
+    print("val scores:", val_scores.mean(axis=1))
+
+    # Find the best C (maximum validation score)
+    best_C_idx = np.argmax(val_scores.mean(axis=1))
+    best_C = param_range[best_C_idx]
+
+    if title is not None:
+        plt.title(f'{title} (Best C: {best_C:.5f})')
+    else:
+        # plt.title(f'Validation Curve for {model.__class__.__name__}')
+        plt.title(f'Validation Curve for Ridge Logistic Regression (Best C: {best_C:.5f})')
+    plt.ylabel('Score')
+    plt.xlabel(f'{param_name} (Regularization parameter)')
+    plt.show()
+
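+
+if __name__ == "__main__":
+    # Minimal smoke test / usage sketch, not part of the evaluation API.
+    # Assumption: synthetic data stands in for the project's real dataset, and the
+    # LogisticRegression estimator and C-grid below are illustrative choices.
+    from sklearn.datasets import make_classification
+    from sklearn.linear_model import LogisticRegression
+
+    X_demo, y_demo = make_classification(n_samples=300, n_features=10, random_state=42)
+    plot_validation_curve(
+        LogisticRegression(max_iter=1000),
+        X_demo,
+        y_demo,
+        param_name="C",
+        param_range=np.logspace(-3, 2, 6),
+        cv=3,
+    )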