[4fa73e]: / pytorch / agents / base.py

Download this file

67 lines (56 with data), 1.7 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
The base Agent class from which all other agents inherit; it declares every method an agent is expected to implement.
"""
import logging
class BaseAgent:
    """
    Base class that every agent inherits from.

    It stores the configuration and a logger, and declares the methods a
    concrete agent is expected to override. Every method other than
    ``__init__`` raises ``NotImplementedError`` in this base class.
    """

    def __init__(self, config):
        """
        Store the agent configuration and create its logger.

        :param config: configuration object for the agent; stored as-is on
            ``self.config`` (its type/schema is defined by the caller).
        """
        self.config = config
        # logging.getLogger returns the same logger object for a given name,
        # so every agent instance shares the "Agent" logger.
        self.logger = logging.getLogger("Agent")

    def load_checkpoint(self, file_name: str):
        """
        Load the latest checkpoint.

        :param file_name: name of the checkpoint file
        :raises NotImplementedError: always in this base class; subclasses
            must override.
        """
        raise NotImplementedError

    def save_checkpoint(self, file_name: str = "checkpoint.pth.tar", is_best: bool = False):
        """
        Save a checkpoint.

        :param file_name: name of the checkpoint file
        :param is_best: flag indicating whether the current checkpoint's
            metric is the best so far
        :raises NotImplementedError: always in this base class; subclasses
            must override.
        """
        # Default was the int 0; False == 0, so existing callers and overrides
        # that compare against 0 or test truthiness behave the same.
        raise NotImplementedError

    def run(self):
        """
        Main entry point that drives the agent.

        :raises NotImplementedError: always in this base class; subclasses
            must override.
        """
        raise NotImplementedError

    def train(self):
        """
        Main training loop.

        :raises NotImplementedError: always in this base class; subclasses
            must override.
        """
        raise NotImplementedError

    def train_one_epoch(self):
        """
        Run one epoch of training.

        :raises NotImplementedError: always in this base class; subclasses
            must override.
        """
        raise NotImplementedError

    def validate(self):
        """
        Run one cycle of model validation.

        :raises NotImplementedError: always in this base class; subclasses
            must override.
        """
        raise NotImplementedError

    def finalize(self):
        """
        Finalize all operations of the two main components of the process:
        the operator and the data loader.

        :raises NotImplementedError: always in this base class; subclasses
            must override.
        """
        raise NotImplementedError