Diff of /src/utils/training.py [000000] .. [0eda78]

--- a
+++ b/src/utils/training.py
@@ -0,0 +1,169 @@
+from utils.dataloader import Dataloader
+from utils.BertArchitecture import BertNER
+from utils.BertArchitecture import BioBertNER
+from utils.metric_tracking import MetricsTracking
+
+import torch
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+import numpy as np
+import pandas as pd
+
+from tqdm import tqdm
+
+def train_loop(model, train_dataset, eval_dataset, optimizer, batch_size, epochs, type, train_sampler=None, eval_sampler=None, verbose=True):
+    """
+    Standard training loop, including training and evaluation.
+
+    Parameters:
+    model (BertNER | BioBertNER): Model to be trained.
+    train_dataset (Custom_Dataset): Dataset used for training.
+    eval_dataset (Custom_Dataset): Dataset used for evaluation.
+    optimizer (torch.optim.Optimizer): Optimizer used, usually SGD or Adam.
+    batch_size (int): Batch size used during training.
+    epochs (int): Number of epochs used for training.
+    type: Forwarded to MetricsTracking.
+    train_sampler (SubsetRandomSampler): Sampler used during hyperparameter tuning.
+    eval_sampler (SubsetRandomSampler): Sampler used during hyperparameter tuning.
+    verbose (bool): If True, evaluate and print metrics after every epoch;
+        otherwise evaluate only once after the final epoch.
+
+    Returns:
+    tuple:
+        - train_res (dict): Averaged metrics obtained during training.
+        - eval_res (dict): Averaged metrics obtained during evaluation.
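+
+    Example (illustrative sketch, not part of the original pipeline; assumes
+    model, train_dataset, eval_dataset and the string expected by
+    MetricsTracking have been created elsewhere):
+
+        optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
+        train_res, eval_res = train_loop(
+            model, train_dataset, eval_dataset, optimizer,
+            batch_size=8, epochs=5, type=entity_type,  # whatever string MetricsTracking expects
+        )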
+    """
+
+    if train_sampler is None or eval_sampler is None:
+        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
+        eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
+    else:
+        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler)
+        eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, sampler=eval_sampler)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = model.to(device)
+
+    # training loop
+    for epoch in range(epochs):
+
+        train_metrics = MetricsTracking(type)
+
+        model.train() #train mode
+
+        for train_data in tqdm(train_dataloader):
+
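+            # move the batch to the device; attention_mask and input_ids carry an
+            # extra singleton dimension from the tokenizer, which squeeze(1) removes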
+            train_label = train_data['entity'].to(device)
+            mask = train_data['attention_mask'].squeeze(1).to(device)
+            input_id = train_data['input_ids'].squeeze(1).to(device)
+
+            optimizer.zero_grad()
+
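+            # forward pass: with labels supplied, the model returns both the loss and token-level logits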
+            output = model(input_id, mask, train_label)
+            loss, logits = output.loss, output.logits
+            predictions = logits.argmax(dim=-1)
+
+            #compute metrics
+            train_metrics.update(predictions, train_label, loss.item())
+
+            loss.backward()
+            optimizer.step()
+
+        if verbose:
+            model.eval() #evaluation mode
+
+            eval_metrics = MetricsTracking(type)
+
+            with torch.no_grad():
+
+                for eval_data in eval_dataloader:
+
+                    eval_label = eval_data['entity'].to(device)
+                    mask = eval_data['attention_mask'].squeeze(1).to(device)
+                    input_id = eval_data['input_ids'].squeeze(1).to(device)
+
+                    output = model(input_id, mask, eval_label)
+                    loss, logits = output.loss, output.logits
+
+                    predictions = logits.argmax(dim=-1)
+
+                    eval_metrics.update(predictions, eval_label, loss.item())
+
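+            # average the tracked metrics over the number of batches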
+            train_results = train_metrics.return_avg_metrics(len(train_dataloader))
+            eval_results = eval_metrics.return_avg_metrics(len(eval_dataloader))
+
+            print(f"Epoch {epoch+1} of {epochs} finished!")
+            print(f"TRAIN\nMetrics {train_results}\n")
+            print(f"VALIDATION\nMetrics {eval_results}\n")
+
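+    # when verbose is off, run a single evaluation pass after the final epoch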
+    if not verbose:
+        model.eval() #evaluation mode
+
+        eval_metrics = MetricsTracking(type)
+
+        with torch.no_grad():
+
+            for eval_data in eval_dataloader:
+
+                eval_label = eval_data['entity'].to(device)
+                mask = eval_data['attention_mask'].squeeze(1).to(device)
+                input_id = eval_data['input_ids'].squeeze(1).to(device)
+
+                output = model(input_id, mask, eval_label)
+                loss, logits = output.loss, output.logits
+
+                predictions = logits.argmax(dim=-1)
+
+                eval_metrics.update(predictions, eval_label, loss.item())
+
+        train_results = train_metrics.return_avg_metrics(len(train_dataloader))
+        eval_results = eval_metrics.return_avg_metrics(len(eval_dataloader))
+
+        print(f"Epoch {epoch+1} of {epochs} finished!")
+        print(f"TRAIN\nMetrics {train_results}\n")
+        print(f"VALIDATION\nMetrics {eval_results}\n")
+
+    return train_results, eval_results
+
+def testing(model, test_dataset, batch_size, type):
+    """
+    Function for testing a trained model.
+
+    Parameters:
+    model (BertNER | BioBertNER): Model to be tested.
+    test_dataset (Custom_Dataset): Dataset used for testing.
+    batch_size (int): Batch size used during testing.
+    type: Forwarded to MetricsTracking, as in train_loop.
+
+    Returns:
+    test_res (dict): Averaged metrics obtained during testing.
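+
+    Example (illustrative sketch; assumes model was trained via train_loop and
+    entity_type matches the string used there):
+
+        test_res = testing(model, test_dataset, batch_size=8, type=entity_type)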
+    """
+
+    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = model.to(device)
+
+    model.eval() #evaluation mode
+
+    test_metrics = MetricsTracking(type)
+
+    with torch.no_grad():
+
+        for test_data in test_dataloader:
+
+            test_label = test_data['entity'].to(device)
+            mask = test_data['attention_mask'].squeeze(1).to(device)
+            input_id = test_data['input_ids'].squeeze(1).to(device)
+
+            output = model(input_id, mask, test_label)
+            loss, logits = output.loss, output.logits
+
+            predictions = logits.argmax(dim=-1)
+
+            test_metrics.update(predictions, test_label, loss.item())
+
+        test_results = test_metrics.return_avg_metrics(len(test_dataloader))
+
+        print(f"TEST\nMetrics {test_results}\n")
+
+    return test_results