medicalbert/classifiers/sequential/sequence_classifier.py

import pandas as pd
import torch
from tqdm import trange, tqdm

from classifiers.standard.classifier import Classifier

class SequenceClassifier(Classifier):

    def train(self, datareader):
        device = torch.device(self.config['device'])
        self.model.train()
        self.model.to(device)

        # Per-batch losses, collected for saving at the end of training.
        batch_losses = []

        # Start from self.epochs so training resumes from the last saved
        # checkpoint rather than restarting at epoch 0.
        for _ in trange(self.epochs, int(self.config['epochs']), desc="Epoch"):
            tr_loss = 0
            with tqdm(datareader.get_train(), desc="Iteration") as batches:
                for step, batch in enumerate(batches):

                    # Move the batch to the target device and unpack it.
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_mask, segment_ids, label_ids = batch

                    # Forward pass: with labels supplied, the model returns
                    # the loss as the first element of its output tuple.
                    # Note that input_mask and segment_ids are unpacked but
                    # not passed to the model here.
                    loss = self.model(input_ids, labels=label_ids)[0]

                    # Scale the loss so gradients accumulated over
                    # gradient_accumulation_steps batches average out.
                    loss = loss / self.config['gradient_accumulation_steps']

                    loss.backward()

                    # Track the running epoch loss and record the per-batch
                    # (scaled) loss so it can be saved after training.
                    tr_loss += loss.item()
                    batch_losses.append(loss.item())

                    if (step + 1) % self.config['gradient_accumulation_steps'] == 0:
                        # Apply the accumulated gradients to update the
                        # model parameters.
                        # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                        self.optimizer.step()
                        self.optimizer.zero_grad()
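                        # e.g. with gradient_accumulation_steps = 4, the
                        # optimizer steps once every 4 batches and each batch
                        # contributes 1/4 of the gradient, emulating a 4x
                        # larger effective batch size in the same memory.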

            # Save a checkpoint at the end of each epoch and advance the
            # epoch counter so a restart resumes from here.
            self.save()
            self.epochs = self.epochs + 1

        self.save_batch_losses(pd.DataFrame(batch_losses))
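
# A minimal usage sketch (an assumption, not part of the original file): the
# Classifier base class is assumed to accept a config dict and to provide
# self.model, self.optimizer, self.epochs, save() and save_batch_losses();
# the datareader is assumed to expose get_train() yielding
# (input_ids, input_mask, segment_ids, label_ids) tensor batches.
#
#     config = {"device": "cuda", "epochs": 3, "gradient_accumulation_steps": 2}
#     classifier = SequenceClassifier(config)
#     classifier.train(datareader)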