src/utils/training.py
from utils.BertArchitecture import BertNER, BioBertNER
from utils.metric_tracking import MetricsTracking

import torch
from torch.utils.data import DataLoader

from tqdm import tqdm

def train_loop(model, train_dataset, eval_dataset, optimizer, batch_size, epochs, type, train_sampler=None, eval_sampler=None, verbose=True):
    """
    Standard training loop, including training and evaluation.

    Parameters:
        model (BertNER | BioBertNER): Model to be trained.
        train_dataset (Custom_Dataset): Dataset used for training.
        eval_dataset (Custom_Dataset): Dataset used for evaluation.
        optimizer (torch.optim.Optimizer): Optimizer used, usually SGD or Adam.
        batch_size (int): Batch size used during training.
        epochs (int): Number of epochs used for training.
        type (str): Identifier passed through to MetricsTracking.
        train_sampler (SubsetRandomSampler): Sampler used during hyperparameter tuning.
        eval_sampler (SubsetRandomSampler): Sampler used during hyperparameter tuning.
        verbose (bool): If True, the model is evaluated after every epoch; otherwise only once after training.

    Returns:
        tuple:
            - train_results (dict): A dictionary containing the results obtained during training.
            - eval_results (dict): A dictionary containing the results obtained during evaluation.
    """

    # During hyperparameter tuning, subset samplers restrict each split;
    # otherwise the full datasets are iterated in order.
    if train_sampler is None or eval_sampler is None:
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
        eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
    else:
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler)
        eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, sampler=eval_sampler)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # training
    for epoch in range(epochs):

        train_metrics = MetricsTracking(type)

        model.train()  # train mode

        for train_data in tqdm(train_dataloader):

            train_label = train_data['entity'].to(device)
            mask = train_data['attention_mask'].squeeze(1).to(device)
            input_id = train_data['input_ids'].squeeze(1).to(device)

            optimizer.zero_grad()

            output = model(input_id, mask, train_label)
            loss, logits = output.loss, output.logits
            predictions = logits.argmax(dim=-1)

            # compute metrics
            train_metrics.update(predictions, train_label, loss.item())

            loss.backward()
            optimizer.step()

        if verbose:
            model.eval()  # evaluation mode

            eval_metrics = MetricsTracking(type)

            with torch.no_grad():

                for eval_data in eval_dataloader:

                    eval_label = eval_data['entity'].to(device)
                    mask = eval_data['attention_mask'].squeeze(1).to(device)
                    input_id = eval_data['input_ids'].squeeze(1).to(device)

                    output = model(input_id, mask, eval_label)
                    loss, logits = output.loss, output.logits

                    predictions = logits.argmax(dim=-1)

                    eval_metrics.update(predictions, eval_label, loss.item())

            train_results = train_metrics.return_avg_metrics(len(train_dataloader))
            eval_results = eval_metrics.return_avg_metrics(len(eval_dataloader))

            print(f"Epoch {epoch+1} of {epochs} finished!")
            print(f"TRAIN\nMetrics {train_results}\n")
            print(f"VALIDATION\nMetrics {eval_results}\n")

    # when not verbose, evaluate only once after training has finished
    if not verbose:
        model.eval()  # evaluation mode

        eval_metrics = MetricsTracking(type)

        with torch.no_grad():

            for eval_data in eval_dataloader:

                eval_label = eval_data['entity'].to(device)
                mask = eval_data['attention_mask'].squeeze(1).to(device)
                input_id = eval_data['input_ids'].squeeze(1).to(device)

                output = model(input_id, mask, eval_label)
                loss, logits = output.loss, output.logits

                predictions = logits.argmax(dim=-1)

                eval_metrics.update(predictions, eval_label, loss.item())

        train_results = train_metrics.return_avg_metrics(len(train_dataloader))
        eval_results = eval_metrics.return_avg_metrics(len(eval_dataloader))

        print(f"Epoch {epoch+1} of {epochs} finished!")
        print(f"TRAIN\nMetrics {train_results}\n")
        print(f"VALIDATION\nMetrics {eval_results}\n")

    return train_results, eval_results
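

# A minimal usage sketch for train_loop. Everything here is illustrative:
# `train_data` and `eval_data` stand for Custom_Dataset instances built
# elsewhere, the BertNER constructor arguments are omitted, and "NER" is a
# placeholder for whatever identifier MetricsTracking expects.
#
#     model = BertNER(...)  # or BioBertNER(...)
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
#     train_res, eval_res = train_loop(model, train_data, eval_data, optimizer,
#                                      batch_size=8, epochs=3, type="NER")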


def testing(model, test_dataset, batch_size, type):
    """
    Function for testing a trained model.

    Parameters:
        model (BertNER | BioBertNER): Model to be tested.
        test_dataset (Custom_Dataset): Dataset used for testing.
        batch_size (int): Batch size used during testing.
        type (str): Identifier passed through to MetricsTracking.

    Returns:
        test_results (dict): A dictionary containing the results obtained during testing.
    """

    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    model.eval()  # evaluation mode

    test_metrics = MetricsTracking(type)

    with torch.no_grad():

        for test_data in test_dataloader:

            test_label = test_data['entity'].to(device)
            mask = test_data['attention_mask'].squeeze(1).to(device)
            input_id = test_data['input_ids'].squeeze(1).to(device)

            output = model(input_id, mask, test_label)
            loss, logits = output.loss, output.logits

            predictions = logits.argmax(dim=-1)

            test_metrics.update(predictions, test_label, loss.item())

    test_results = test_metrics.return_avg_metrics(len(test_dataloader))

    print(f"TEST\nMetrics {test_results}\n")

    return test_results
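

if __name__ == "__main__":
    # Self-contained sanity check of the prediction step shared by the loops
    # above (no model or data required): token logits of shape
    # (batch, seq_len, num_labels) argmax down to (batch, seq_len) label ids.
    dummy_logits = torch.randn(2, 8, 5)
    print(dummy_logits.argmax(dim=-1).shape)  # torch.Size([2, 8])

    # A hypothetical end-to-end call, with model and Custom_Dataset instances
    # built elsewhere:
    #     test_results = testing(model, test_dataset, batch_size=8, type="NER")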