Diff of /utils.py [000000] .. [39fb2b]

b/utils.py
import json
import logging
import os
import shutil
import torch

# from visdom import Visdom
# import numpy as np
# import matplotlib.pyplot as plt

class Params():
    """Class that loads hyperparameters from a json file.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        with open(json_path) as f:
            params = json.load(f)
            self.__dict__.update(params)

    def save(self, json_path):
        with open(json_path, 'w') as f:
            json.dump(self.__dict__, f, indent=4)

    def update(self, json_path):
        """Loads parameters from a json file"""
        with open(json_path) as f:
            params = json.load(f)
            self.__dict__.update(params)

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
        return self.__dict__


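# Usage sketch for Params.save / Params.update (a hypothetical 'params.json' is
# assumed, e.g. containing {"learning_rate": 1e-3}):
#
#     params = Params('params.json')
#     params.learning_rate = 0.5          # change the value in memory
#     params.save('params_tuned.json')    # write all fields back to a json file
#     params.update('params.json')        # overwrite fields from another json file

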
class RunningAverage():
    """A simple class that maintains the running average of a quantity

    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg() == 3
    ```
    """
    def __init__(self):
        self.steps = 0
        self.total = 0

    def update(self, val):
        self.total += val
        self.steps += 1

    def __call__(self):
        return self.total / float(self.steps)


def set_logger(log_path):
    """Set the logger to log info in terminal and file `log_path`.

    In general, it is useful to have a logger so that every output to the terminal is saved
    in a permanent file. Here we save it to `model_dir/train.log`.

    Example:
    ```
    logging.info("Starting training...")
    ```

    Args:
        log_path: (string) where to log
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        # Logging to a file
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        logger.addHandler(file_handler)

        # Logging to console
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(stream_handler)


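# Usage sketch for set_logger (model_dir is hypothetical; call once at startup,
# afterwards the stdlib logging calls go to both the console and the log file):
#
#     set_logger(os.path.join(model_dir, 'train.log'))
#     logging.info("Starting training...")

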
def save_dict_to_json(d, json_path):
    """Saves dict of floats in json file

    Args:
        d: (dict) of float-castable values (np.float, int, float, etc.)
        json_path: (string) path to json file
    """
    with open(json_path, 'w') as f:
        # We need to convert the values to float for json (it doesn't accept np.array, np.float, etc.)
        d = {k: float(v) for k, v in d.items()}
        json.dump(d, f, indent=4)


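# Usage sketch for save_dict_to_json (the metrics dict and output path are
# hypothetical; values only need to be float-castable):
#
#     metrics = {'loss': 0.231, 'accuracy': 0.87}
#     save_dict_to_json(metrics, os.path.join(model_dir, 'metrics_val_last.json'))

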
def save_checkpoint(state, is_best, checkpoint):
    """Saves model and training parameters at checkpoint + 'last.pth.tar'. If is_best==True, also saves
    checkpoint + 'best.pth.tar'

    Args:
        state: (dict) contains model's state_dict, may contain other keys such as epoch, optimizer state_dict
        is_best: (bool) True if it is the best model seen till now
        checkpoint: (string) folder where parameters are to be saved
    """
    filepath = os.path.join(checkpoint, 'last.pth.tar')
    if not os.path.exists(checkpoint):
        print("Checkpoint Directory does not exist! Making directory {}".format(checkpoint))
        os.mkdir(checkpoint)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar'))


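# Usage sketch for save_checkpoint (model_dir and is_best are hypothetical; the
# 'state_dict'/'optim_dict' keys mirror what load_checkpoint reads back):
#
#     save_checkpoint({'epoch': epoch + 1,
#                      'state_dict': model.state_dict(),
#                      'optim_dict': optimizer.state_dict()},
#                     is_best=is_best,
#                     checkpoint=model_dir)

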
def load_checkpoint(checkpoint, model, optimizer=None, mines=None, optims_mine=None, **kwargs):
    """Loads model parameters (state_dict) from the `checkpoint` file. If optimizer is provided, loads state_dict of
    the optimizer assuming it is present in the checkpoint.

    Args:
        checkpoint: (string) filename which needs to be loaded
        model: (torch.nn.Module) model for which the parameters are loaded
        optimizer: (torch.optim) optional: resume optimizer from checkpoint
        mines: (dict) optional: mine estimators to restore from the checkpoint, keyed by name
        optims_mine: (dict) optional: optimizers for the mine estimators, keyed like `mines`
    """
    if not os.path.exists(checkpoint):
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))
    checkpoint = torch.load(checkpoint)
    model.load_state_dict(checkpoint['state_dict'], **kwargs)

    if optimizer:
        optimizer.load_state_dict(checkpoint['optim_dict'])

    if mines:
        for mi_name, mine in mines.items():
            mine.load_state_dict(checkpoint[mi_name])
            if optims_mine:
                optims_mine[mi_name].load_state_dict(checkpoint[mi_name + "_optim"])

    return checkpoint
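

# Minimal smoke test of the checkpoint round trip (illustrative sketch: it assumes
# torch is importable and only writes into a temporary directory).
if __name__ == "__main__":
    import tempfile

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Save both the model and optimizer state under the keys load_checkpoint expects.
        save_checkpoint({'epoch': 1,
                         'state_dict': model.state_dict(),
                         'optim_dict': optimizer.state_dict()},
                        is_best=True,
                        checkpoint=tmp_dir)
        # Restore from the 'best' copy written by save_checkpoint.
        load_checkpoint(os.path.join(tmp_dir, 'best.pth.tar'), model, optimizer)
        print("Checkpoint round trip OK")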