Diff of /src/utils/training.py [000000] .. [0eda78]

from utils.dataloader import Dataloader
from utils.BertArchitecture import BertNER, BioBertNER
from utils.metric_tracking import MetricsTracking

import torch
from torch.optim import SGD
from torch.utils.data import DataLoader

import numpy as np
import pandas as pd

from tqdm import tqdm

def train_loop(model, train_dataset, eval_dataset, optimizer, batch_size, epochs, type, train_sampler=None, eval_sampler=None, verbose=True):
    """
    Usual training loop, including training and evaluation.

    Parameters:
    model (BertNER | BioBertNER): Model to be trained.
    train_dataset (Custom_Dataset): Dataset used for training.
    eval_dataset (Custom_Dataset): Dataset used for evaluation.
    optimizer (torch.optim.Optimizer): Optimizer used, usually SGD or Adam.
    batch_size (int): Batch size used during training.
    epochs (int): Number of epochs used for training.
    type: Identifier passed through to MetricsTracking.
    train_sampler (SubsetRandomSampler): Sampler used during hyperparameter tuning.
    eval_sampler (SubsetRandomSampler): Sampler used during hyperparameter tuning.
    verbose (bool): Whether the model should be evaluated after every epoch or only once after the final epoch.

    Returns:
    tuple:
        - train_results (dict): A dictionary containing the averaged metrics obtained during training.
        - eval_results (dict): A dictionary containing the averaged metrics obtained during evaluation.
    """

    if train_sampler is None or eval_sampler is None:
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
        eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
    else:
        # samplers are provided during hyperparameter tuning and take over batch selection
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler)
        eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, sampler=eval_sampler)
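
    # A hedged aside on the tuning path above: the samplers might be built from
    # fold indices with torch.utils.data.SubsetRandomSampler, as the docstring
    # suggests. The index variables below are hypothetical, not defined here:
    #
    #   from torch.utils.data import SubsetRandomSampler
    #   train_sampler = SubsetRandomSampler(train_fold_idx)
    #   eval_sampler = SubsetRandomSampler(eval_fold_idx)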

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # training
    for epoch in range(epochs):

        train_metrics = MetricsTracking(type)

        model.train() # train mode

        for train_data in tqdm(train_dataloader):

            train_label = train_data['entity'].to(device)
            # squeeze away the extra dimension the tokenizer adds per sample
            mask = train_data['attention_mask'].squeeze(1).to(device)
            input_id = train_data['input_ids'].squeeze(1).to(device)

            optimizer.zero_grad()

            # passing the labels makes the model return a loss alongside the logits
            output = model(input_id, mask, train_label)
            loss, logits = output.loss, output.logits
            predictions = logits.argmax(dim=-1)

            # compute metrics
            train_metrics.update(predictions, train_label, loss.item())

            loss.backward()
            optimizer.step()

        if verbose:
            model.eval() # evaluation mode

            eval_metrics = MetricsTracking(type)

            with torch.no_grad():

                for eval_data in eval_dataloader:

                    eval_label = eval_data['entity'].to(device)
                    mask = eval_data['attention_mask'].squeeze(1).to(device)
                    input_id = eval_data['input_ids'].squeeze(1).to(device)

                    output = model(input_id, mask, eval_label)
                    loss, logits = output.loss, output.logits

                    predictions = logits.argmax(dim=-1)

                    eval_metrics.update(predictions, eval_label, loss.item())

            # average the accumulated metrics over the number of batches
            train_results = train_metrics.return_avg_metrics(len(train_dataloader))
            eval_results = eval_metrics.return_avg_metrics(len(eval_dataloader))

            print(f"Epoch {epoch+1} of {epochs} finished!")
            print(f"TRAIN\nMetrics {train_results}\n")
            print(f"VALIDATION\nMetrics {eval_results}\n")

    if not verbose:
        model.eval() # evaluation mode

        eval_metrics = MetricsTracking(type)

        with torch.no_grad():

            for eval_data in eval_dataloader:

                eval_label = eval_data['entity'].to(device)
                mask = eval_data['attention_mask'].squeeze(1).to(device)
                input_id = eval_data['input_ids'].squeeze(1).to(device)

                output = model(input_id, mask, eval_label)
                loss, logits = output.loss, output.logits

                predictions = logits.argmax(dim=-1)

                eval_metrics.update(predictions, eval_label, loss.item())

        # train_metrics still holds the results of the final training epoch
        train_results = train_metrics.return_avg_metrics(len(train_dataloader))
        eval_results = eval_metrics.return_avg_metrics(len(eval_dataloader))

        print(f"Epoch {epoch+1} of {epochs} finished!")
        print(f"TRAIN\nMetrics {train_results}\n")
        print(f"VALIDATION\nMetrics {eval_results}\n")

    return train_results, eval_results
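
# A minimal usage sketch (not part of the original file). The BertNER
# constructor argument, the optimizer hyperparameters, and the entity_type
# variable are assumptions for illustration only:
#
#   model = BertNER(3)                          # hypothetical label count
#   optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
#   train_res, eval_res = train_loop(model, train_dataset, eval_dataset,
#                                    optimizer, batch_size=8, epochs=5,
#                                    type=entity_type)  # hypothetical variable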

def testing(model, test_dataset, batch_size, type):
    """
    Function for testing a trained model.

    Parameters:
    model (BertNER | BioBertNER): Model to be tested.
    test_dataset (Custom_Dataset): Dataset used for testing.
    batch_size (int): Batch size used during testing.
    type: Identifier passed through to MetricsTracking.

    Returns:
    test_results (dict): A dictionary containing the averaged metrics obtained during testing.
    """

    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    model.eval() # evaluation mode

    test_metrics = MetricsTracking(type)

    with torch.no_grad():

        for test_data in test_dataloader:

            test_label = test_data['entity'].to(device)
            mask = test_data['attention_mask'].squeeze(1).to(device)
            input_id = test_data['input_ids'].squeeze(1).to(device)

            output = model(input_id, mask, test_label)
            loss, logits = output.loss, output.logits

            predictions = logits.argmax(dim=-1)

            test_metrics.update(predictions, test_label, loss.item())

    test_results = test_metrics.return_avg_metrics(len(test_dataloader))

    print(f"TEST\nMetrics {test_results}\n")

    return test_results
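
# A minimal usage sketch (not part of the original file; assumes a trained
# model, a Custom_Dataset test split, and the same hypothetical entity_type
# as above):
#
#   test_res = testing(model, test_dataset, batch_size=8, type=entity_type)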