[d129b2]: / medicalbert / classifiers / sequential / sequence_classifier.py

Download this file

43 lines (29 with data), 1.4 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import pandas as pd
import torch
from tqdm import trange, tqdm
from classifiers.standard.classifier import Classifier
class SequenceClassifier(Classifier):
    """Classifier whose training loop passes only token ids and labels to the model."""

    def train(self, datareader):
        """Train ``self.model`` on the batches produced by ``datareader``.

        Runs from epoch ``self.epochs`` up to ``int(self.config['epochs'])``.
        Each batch loss is divided by ``config['gradient_accumulation_steps']``
        before ``backward()``, and the optimizer steps once every that many
        batches. After every epoch ``self.save()`` is called and
        ``self.epochs`` is incremented. When all epochs have run, the
        per-batch losses are handed to ``self.save_batch_losses`` as a
        single-column ``pandas.DataFrame`` (one row per batch, in order).

        Args:
            datareader: object exposing ``get_train()``, which must return an
                iterable of batches. Each batch is unpacked into exactly four
                items (input ids, input mask, segment ids, label ids), each
                supporting ``.to(device)``.
        """
        device = torch.device(self.config['device'])
        # Loop-invariant; looked up once instead of twice per batch.
        accumulation_steps = self.config['gradient_accumulation_steps']

        self.model.train()
        self.model.to(device)

        # One entry per processed batch, across all epochs.
        batch_losses = []

        for _ in trange(self.epochs, int(self.config['epochs']), desc="Epoch"):
            with tqdm(datareader.get_train(), desc="Iteration") as progress:
                for step, batch in enumerate(progress):
                    batch = tuple(tensor.to(device) for tensor in batch)
                    # NOTE(review): input_mask and segment_ids are unpacked but
                    # never passed to the model -- presumably intentional for
                    # this classifier; confirm.
                    input_ids, input_mask, segment_ids, label_ids = batch

                    # The model is expected to return the loss as its first element.
                    loss = self.model(input_ids, labels=label_ids)[0]

                    # Record the unscaled loss. Previously nothing was ever
                    # appended, so the saved DataFrame was always empty.
                    batch_losses.append(loss.item())

                    # Scale so gradients summed over an accumulation window
                    # match the average over that window.
                    loss = loss / accumulation_steps
                    loss.backward()

                    if (step + 1) % accumulation_steps == 0:
                        # Apply the accumulated gradients, then reset them.
                        # Gradient clipping is currently not applied.
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                # NOTE(review): if the number of batches is not a multiple of
                # accumulation_steps, the trailing gradients are not stepped
                # here and remain in .grad when the next epoch starts -- confirm
                # whether a final flush is wanted.

            # Checkpoint at the end of every epoch, then advance the counter.
            self.save()
            self.epochs = self.epochs + 1

        self.save_batch_losses(pd.DataFrame(batch_losses))