medicalbert/classifiers/sequential/sequence_classifier.py
import pandas as pd
import torch
from tqdm import trange, tqdm

from classifiers.standard.classifier import Classifier

class SequenceClassifier(Classifier):

    def train(self, datareader):
        # Run training on the device named in the config (e.g. "cuda" or "cpu").
        device = torch.device(self.config['device'])
        self.model.train()
        self.model.to(device)

        batch_losses = []

        # Start from the number of epochs already completed (self.epochs), so a
        # run restored from a checkpoint resumes rather than starting over.
        for _ in trange(self.epochs, int(self.config['epochs']), desc="Epoch"):
            tr_loss = 0
            with tqdm(datareader.get_train(), desc="Iteration") as t:
                for step, batch in enumerate(t):

                    # Move every tensor in the batch onto the target device.
                    batch = tuple(tensor.to(device) for tensor in batch)
                    input_ids, input_mask, segment_ids, label_ids = batch

                    # Forward pass; when labels are supplied the model returns
                    # the loss as the first element of its outputs.
                    loss = self.model(input_ids, labels=label_ids)[0]

                    # Scale the loss so gradients accumulated over several
                    # steps average out rather than sum.
                    loss = loss / self.config['gradient_accumulation_steps']

                    loss.backward()

                    # Track the (scaled) loss for logging and for the batch-loss
                    # history written out at the end of training.
                    tr_loss += loss.item()
                    batch_losses.append(loss.item())

                    if (step + 1) % self.config['gradient_accumulation_steps'] == 0:
                        # Apply the accumulated gradients, then reset them.
                        # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                        self.optimizer.step()
                        self.optimizer.zero_grad()

            # Checkpoint the model and record the completed epoch so a
            # later run can resume where this one stopped.
            self.save()
            self.epochs = self.epochs + 1

        self.save_batch_losses(pd.DataFrame(batch_losses))
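
# A minimal usage sketch, under assumptions not confirmed by this module: the
# Classifier base class is assumed to provide self.config, self.model,
# self.optimizer, self.epochs, save() and save_batch_losses(), and the
# datareader's get_train() is assumed to yield batches of
# (input_ids, input_mask, segment_ids, label_ids) tensors.
#
#     config = {"device": "cuda", "epochs": 3, "gradient_accumulation_steps": 2}
#     classifier = SequenceClassifier(config)   # hypothetical constructor
#     classifier.train(datareader)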