
medicalbert/classifiers/standard/classifier.py

import logging
import os
from statistics import mean

import gcsfs
import pandas as pd
import torch
from tqdm import trange, tqdm


###
# Base class for Bert classifiers.
###
class Classifier:

    def train(self, datareader):
        device = torch.device(self.config['device'])
        self.model.train()
        self.model.to(device)

        batch_losses = []

        for _ in trange(self.epochs, int(self.config['epochs']), desc="Epoch"):
            tr_loss = 0
            step_losses = []
            with tqdm(datareader.get_train(), desc="Iteration") as t:
                for step, batch in enumerate(t):

                    # Move every tensor in the batch to the target device and unpack it.
                    batch = tuple(tensor.to(device) for tensor in batch)
                    input_ids, input_mask, segment_ids, label_ids = batch

                    loss = self.model(input_ids, labels=label_ids)[0]

                    # Statistics
                    step_losses.append(loss.item())

                    # Scale the loss so accumulated gradients average over the effective batch.
                    loss = loss / self.config['gradient_accumulation_steps']

                    loss.backward()

                    tr_loss += loss.item()

                    if (step + 1) % self.config['gradient_accumulation_steps'] == 0:
                        batch_losses.append(mean(step_losses))
                        # Update the model parameters
                        #torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                        self.optimizer.step()
                        self.optimizer.zero_grad()

            # save a checkpoint here
            self.save()
            self.epochs = self.epochs + 1

        self.save_batch_losses(pd.DataFrame(batch_losses))

    def save_batch_losses(self, losses):
        path = os.path.join(self.config['output_dir'], self.config['experiment_name'])
        # Only create the directory for local paths; "gs://" locations need no directories.
        if path[:2] != "gs":
            if not os.path.exists(path):
                os.makedirs(path)

        losses.to_csv(os.path.join(path, "batch_loss.csv"))

    def set_eval_mode(self):
        self.model.eval()

    def load_object_from_location(self, checkpoint_file):
        # Local checkpoints are loaded directly; "gs://" paths are read through gcsfs.
        if checkpoint_file[:2] != "gs":
            return torch.load(checkpoint_file)
        else:
            fs = gcsfs.GCSFileSystem()
            with fs.open(checkpoint_file, mode='rb') as f:
                return torch.load(f)

    def load_from_checkpoint(self, checkpoint_file=None):
        # Fall back to the checkpoint named in the config when no file is given.
        if checkpoint_file is None:
            if 'load_from_checkpoint' not in self.config:
                return
            checkpoint_file = self.config['load_from_checkpoint']

        file_path = os.path.join(self.config['output_dir'], self.config['experiment_name'],
                                 "checkpoints", checkpoint_file)
        checkpoint = self.load_object_from_location(file_path)

        self.epochs = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['bert_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        # Workaround - reloading an optimizer whose state was built with CUDA tensors
        # causes an error - see https://github.com/pytorch/pytorch/issues/2830
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    if self.config['device'] == 'gpu':
                        state[k] = v.cuda()
                    else:
                        state[k] = v

    def save_object_to_location(self, obj):
        checkpoint_dir = os.path.join(self.config['output_dir'], self.config['experiment_name'], "checkpoints")
        file_name = os.path.join(checkpoint_dir, str(self.epochs))

        # Local checkpoints need the directory created first; "gs://" paths are written through gcsfs.
        if self.config['output_dir'][:2] != "gs":
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            torch.save(obj, file_name)
        else:
            fs = gcsfs.GCSFileSystem()
            with fs.open(file_name, mode='wb') as f:
                torch.save(obj, f)

    def save(self):
        checkpoint = {
            'epoch': self.epochs + 1,
            'bert_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        self.save_object_to_location(checkpoint)
        logging.info("Saved model")
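
For orientation (the subclass constructor is not part of this file), here is a minimal, hypothetical sketch of how the class might be driven. `BertClassifier`, `ToyDataReader`, and the `pretrained_model` and `learning_rate` config keys are illustrative assumptions, not names from the repository; the remaining keys (`device`, `epochs`, `gradient_accumulation_steps`, `output_dir`, `experiment_name`) are the ones the code above actually reads.

import torch
from torch.optim import AdamW
from transformers import BertForSequenceClassification


class BertClassifier(Classifier):
    """Hypothetical subclass wiring up the attributes the base class expects."""

    def __init__(self, config):
        self.config = config
        self.epochs = 0  # completed epochs; also used as the checkpoint file name
        self.model = BertForSequenceClassification.from_pretrained(config['pretrained_model'])
        self.optimizer = AdamW(self.model.parameters(), lr=config['learning_rate'])


class ToyDataReader:
    """Stand-in data reader: batches of (input_ids, input_mask, segment_ids, label_ids)."""

    def get_train(self):
        input_ids = torch.randint(0, 30522, (2, 16))
        return [(input_ids, torch.ones_like(input_ids), torch.zeros_like(input_ids),
                 torch.tensor([0, 1]))] * 4


config = {
    'device': 'cpu',                    # passed to torch.device(...)
    'epochs': 1,
    'gradient_accumulation_steps': 4,
    'output_dir': 'results',            # a "gs://bucket/..." prefix switches to gcsfs
    'experiment_name': 'demo-run',
    'pretrained_model': 'bert-base-uncased',
    'learning_rate': 2e-5,
}

classifier = BertClassifier(config)
# classifier.load_from_checkpoint("0")  # optionally resume from a saved checkpoint
classifier.train(ToyDataReader())       # writes checkpoints and batch_loss.csv under results/demo-run/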