--- a
+++ b/src/multi_therapy.py
@@ -0,0 +1,165 @@
+import kagglehub
+import os
+import pandas as pd
+from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler, OneHotEncoder
+from sklearn.compose import ColumnTransformer
+from torch.utils.data import DataLoader, TensorDataset
+
+# Download the METABRIC dataset from KaggleHub
+path = kagglehub.dataset_download("raghadalharbi/breast-cancer-gene-expression-profiles-metabric")
+print("Path to dataset files:", path)
+
+# Load dataset
+df = pd.read_csv(os.path.join(path, "METABRIC_RNA_Mutation.csv"))
+device = torch.device("cpu")
+
+# Selected features for prediction
+selected_features = [
+    "age_at_diagnosis", "tumor_size", "tumor_stage", "lymph_nodes_examined_positive",
+    "nottingham_prognostic_index", "cellularity", "neoplasm_histologic_grade",
+    "inferred_menopausal_state", "er_status_measured_by_ihc", "pr_status", "her2_status"
+]
+
+# Create a simplified therapy class label:
+# 0 = Chemotherapy only, 1 = Radiotherapy only, 2 = Hormone therapy only
+df["therapy_class_simple"] = -1
+df.loc[(df["chemotherapy"] == 1) & (df["hormone_therapy"] == 0) & (df["radio_therapy"] == 0), "therapy_class_simple"] = 0
+df.loc[(df["chemotherapy"] == 0) & (df["hormone_therapy"] == 0) & (df["radio_therapy"] == 1), "therapy_class_simple"] = 1
+df.loc[(df["chemotherapy"] == 0) & (df["hormone_therapy"] == 1) & (df["radio_therapy"] == 0), "therapy_class_simple"] = 2
+
+# Keep only rows that received exactly one of the three therapies
+df = df[df["therapy_class_simple"] != -1].copy()
+df = df.dropna(subset=selected_features + ["therapy_class_simple"])
+
+# Define features and target
+X = df[selected_features].copy()
+y = df["therapy_class_simple"].astype(int)
+
+# Identify categorical and numerical columns
+categorical_cols = X.select_dtypes(include=["object", "bool", "category"]).columns.tolist()
+X[categorical_cols] = X[categorical_cols].astype(str)
+numeric_cols = X.select_dtypes(include=["int64", "float64"]).columns.tolist()
+
+# Preprocessing pipeline with scaling and one-hot encoding
+preprocessor = ColumnTransformer([
+    ("num", StandardScaler(), numeric_cols),
+    ("cat", OneHotEncoder(drop="first", handle_unknown="ignore"), categorical_cols)
+], sparse_threshold=0)  # always return a dense array so torch.tensor() below works
+X_processed = preprocessor.fit_transform(X)
+
+input_dim = X_processed.shape[1]
+num_classes = len(np.unique(y))
+
+# Convert to PyTorch tensors
+X_tensor = torch.tensor(X_processed, dtype=torch.float32).to(device)
+y_tensor = torch.tensor(y.values, dtype=torch.long).to(device)
+
+# Split into training and test sets
+X_train, X_test, y_train, y_test = train_test_split(
+    X_tensor, y_tensor, test_size=0.2, stratify=y, random_state=42
+)
+train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)
+
+# Define the neural network model for multiclass classification
+class MulticlassTherapyModel(nn.Module):
+    def __init__(self, in_dim, out_dim):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(in_dim, 32),
+            nn.ReLU(),
+            nn.Dropout(0.5),
+            nn.Linear(32, 16),
+            nn.ReLU(),
+            nn.Dropout(0.5),
+            nn.Linear(16, out_dim)
+        )
+    def forward(self, x):
+        return self.net(x)
+
+# Initialize model, loss function, and optimizer
+model = MulticlassTherapyModel(input_dim, num_classes).to(device)
+loss_fn = nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+
+# Setup for early stopping
+best_val_loss = float("inf")
+patience = 10
+counter = 0
+best_model_state = None
+train_losses, val_losses = [], []
+
+# Training loop
+for epoch in range(200):
+    model.train()
+    epoch_loss = 0
+    for xb, yb in train_loader:
+        preds = model(xb)
+        loss = loss_fn(preds, yb)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        epoch_loss += loss.item()
+    train_losses.append(epoch_loss / len(train_loader))
+
+    # Validation loss
+    model.eval()
+    with torch.no_grad():
+        val_preds = model(X_test)
+        val_loss = loss_fn(val_preds, y_test)
+    val_losses.append(val_loss.item())
+
+    print(f"Epoch {epoch+1}/200 | Train Loss: {train_losses[-1]:.4f} | Val Loss: {val_losses[-1]:.4f}")
+
+    # Check early stopping condition
+    if val_loss.item() < best_val_loss:
+        best_val_loss = val_loss.item()
+        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}  # snapshot a copy, not a live reference
+        counter = 0
+    else:
+        counter += 1
+        if counter >= patience:
+            print(f"Early stopping at epoch {epoch+1}")
+            break
+
+# Restore best model
+if best_model_state is not None:
+    model.load_state_dict(best_model_state)
+
+# Plot and save learning curve
+os.makedirs("../outputs/figures", exist_ok=True)
+plt.figure(figsize=(8, 5))
+plt.plot(train_losses, label="Train Loss")
+plt.plot(val_losses, label="Validation Loss")
+plt.xlabel("Epoch")
+plt.ylabel("CrossEntropy Loss")
+plt.title("Learning Curve (Therapy: Chemo vs Radio vs Hormone)")
+plt.legend()
+plt.tight_layout()
+plt.savefig("../outputs/figures/nn_learning_curve_therapy_simple.png")
+
+# Model evaluation
+model.eval()
+with torch.no_grad():
+    logits = model(X_test)
+    y_pred = torch.argmax(logits, dim=1).cpu().numpy()
+    y_true = y_test.cpu().numpy()
+
+# Print classification report
+report = classification_report(y_true, y_pred, output_dict=True)
+report_df = pd.DataFrame(report).transpose()
+print(report_df)
+
+# Plot and save confusion matrix
+cm = confusion_matrix(y_true, y_pred)
+disp = ConfusionMatrixDisplay(confusion_matrix=cm)
+disp.plot(cmap="Blues", xticks_rotation=45)
+plt.title("Confusion Matrix - Multiclass Therapy Prediction")
+plt.tight_layout()
+os.makedirs("../outputs/figures", exist_ok=True)
+plt.savefig("../outputs/figures/confusion_matrix_multiclass_therapy.png")