# baselines/logger.py
1
import os
2
import sys
3
import shutil
4
import os.path as osp
5
import json
6
import time
7
import datetime
8
import tempfile
9
from collections import defaultdict
10
11
LOG_OUTPUT_FORMATS     = ['stdout', 'log', 'csv']
12
LOG_OUTPUT_FORMATS_MPI = ['log']
13
# Also valid: json, tensorboard
14
15
DEBUG = 10
16
INFO = 20
17
WARN = 30
18
ERROR = 40
19
20
DISABLED = 50
21
22
class KVWriter(object):
    """Interface for output formats that can write a dict of key/value diagnostics."""
    def writekvs(self, kvs):
        """Write one iteration's key -> value mapping. Subclasses must override."""
        raise NotImplementedError
class SeqWriter(object):
    """Interface for output formats that can write a sequence of text fragments."""
    def writeseq(self, seq):
        """Write the items of seq as one log line. Subclasses must override."""
        raise NotImplementedError
class HumanOutputFormat(KVWriter, SeqWriter):
    """Write diagnostics as a human-readable ASCII table (for stdout / log files)."""

    def __init__(self, filename_or_file):
        """
        filename_or_file: a str path (the file is opened and owned by this
        writer, and closed by close()), or an already-open writable
        file-like object (e.g. sys.stdout), which is never closed here.
        """
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, 'wt')
            self.own_file = True
        else:
            # Bug fix: this writer only ever writes, so the duck-type check
            # must be for 'write' -- the old check for 'read' wrongly
            # rejected write-only streams.
            assert hasattr(filename_or_file, 'write'), 'expected file or str, got %s'%filename_or_file
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        # Create strings for printing, truncating overly long keys/values.
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if isinstance(val, float):
                valstr = '%-8.3g' % (val,)
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Find max widths
        if len(key2str) == 0:
            print('WARNING: tried to write empty key-value dict')
            return
        else:
            keywidth = max(map(len, key2str.keys()))
            valwidth = max(map(len, key2str.values()))

        # Write out the data as a box-drawn table, keys sorted.
        dashes = '-' * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in sorted(key2str.items()):
            lines.append('| %s%s | %s%s |' % (
                key,
                ' ' * (keywidth - len(key)),
                val,
                ' ' * (valwidth - len(val)),
            ))
        lines.append(dashes)
        self.file.write('\n'.join(lines) + '\n')

        # Flush so progress is visible even if the process dies mid-run.
        self.file.flush()

    def _truncate(self, s):
        # Cap display width at 23 chars; longer strings become 20 chars + '...'.
        return s[:20] + '...' if len(s) > 23 else s

    def writeseq(self, seq):
        for arg in seq:
            self.file.write(arg)
        self.file.write('\n')
        self.file.flush()

    def close(self):
        # Only close files we opened ourselves (never e.g. sys.stdout).
        if self.own_file:
            self.file.close()
class JSONOutputFormat(KVWriter):
    """Write each iteration's diagnostics as one JSON object per line."""

    def __init__(self, filename):
        self.file = open(filename, 'wt')

    def writekvs(self, kvs):
        # Convert numpy values (anything carrying a .dtype) to native Python
        # via .tolist(): a plain number for 0-d values, a (nested) list for
        # arrays.  Bug fix: the old code did float(v.tolist()), which crashed
        # on any array with more than one element; it also mutated the
        # caller's dict in place -- we build a fresh dict instead.
        out = {k: (v.tolist() if hasattr(v, 'dtype') else v)
               for k, v in kvs.items()}
        self.file.write(json.dumps(out) + '\n')
        self.file.flush()

    def close(self):
        self.file.close()
class CSVOutputFormat(KVWriter):
    """Write diagnostics as CSV, rewriting the file when new columns appear."""

    def __init__(self, filename):
        # 'w+t': writekvs must re-read the file to patch in new columns.
        self.file = open(filename, 'w+t')
        self.keys = []
        self.sep = ','

    def writekvs(self, kvs):
        # Columns seen for the first time this iteration.  Sorted so the
        # resulting column order is deterministic (the raw set difference
        # used before had arbitrary iteration order).
        extra_keys = sorted(kvs.keys() - self.keys)
        if extra_keys:
            self.keys.extend(extra_keys)
            # Rewrite the whole file: new header row, then pad every
            # previously written row with empty cells for the new columns.
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(self.sep)  # was a hard-coded ','
                self.file.write(k)
            self.file.write('\n')
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write('\n')
        # Append the current row; missing keys become empty cells.
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(self.sep)  # was a hard-coded ','
            v = kvs.get(k)
            if v is not None:
                self.file.write(str(v))
        self.file.write('\n')
        self.file.flush()

    def close(self):
        self.file.close()
class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.
    """
    def __init__(self, dir):
        # TensorBoard discovers files named 'events*' inside the directory.
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1  # monotonically increasing step attached to written events
        prefix = 'events'
        path = osp.join(osp.abspath(dir), prefix)
        # Imported lazily so tensorflow is only required when this format is used.
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat
        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        # NOTE(review): EventsWriter is a TF-internal API (pywrap_tensorflow);
        # it may move or break between TF versions -- confirm against the
        # pinned tensorflow version.
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))

    def writekvs(self, kvs):
        def summary_val(k, v):
            # Values are coerced to float, so non-numeric values raise here.
            kwargs = {'tag': k, 'simple_value': float(v)}
            return self.tf.Summary.Value(**kwargs)
        summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
        event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
        event.step = self.step # is there any reason why you'd want to specify the step?
        self.writer.WriteEvent(event)
        self.writer.Flush()
        self.step += 1

    def close(self):
        # Idempotent: Close() is called at most once, then the writer is dropped.
        if self.writer:
            self.writer.Close()
            self.writer = None
def make_output_format(format, ev_dir, log_suffix=''):
    """
    Build one output writer of the requested kind, creating ev_dir if needed.

    format: one of 'stdout', 'log', 'json', 'csv', 'tensorboard'
    ev_dir: directory that file-based writers put their output in
    log_suffix: appended to filenames (used to separate MPI ranks)

    Raises ValueError for an unrecognized format name.
    """
    os.makedirs(ev_dir, exist_ok=True)
    # Lazily-evaluated factory table; only the chosen writer is constructed.
    factories = {
        'stdout': lambda: HumanOutputFormat(sys.stdout),
        'log': lambda: HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix)),
        'json': lambda: JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix)),
        'csv': lambda: CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix)),
        'tensorboard': lambda: TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix)),
    }
    if format not in factories:
        raise ValueError('Unknown format specified: %s' % (format,))
    return factories[format]()
# ================================================================
# API
# ================================================================

def logkv(key, val):
    """
    Log a value of some diagnostic
    Call this once for each diagnostic quantity, each iteration
    If called many times, last value will be used.

    Values accumulate on the global Logger.CURRENT until dumpkvs() writes
    and clears them.
    """
    Logger.CURRENT.logkv(key, val)
def logkv_mean(key, val):
    """
    The same as logkv(), but if called many times within one iteration,
    the logged value is the running mean of all the calls.
    """
    Logger.CURRENT.logkv_mean(key, val)
def logkvs(d):
    """
    Log a dictionary of key-value pairs
    (equivalent to calling logkv() on every item of d).
    """
    for (k, v) in d.items():
        logkv(k, v)
def dumpkvs():
    """
    Write all of the diagnostics from the current iteration to every
    configured output format, then clear the accumulated values.
    """
    # (A stale docstring here used to document a 'level' argument that this
    # function never had.)
    Logger.CURRENT.dumpkvs()
def getkvs():
    """Return the dict of diagnostics logged so far this iteration (not a copy)."""
    return Logger.CURRENT.name2val
def log(*args, level=INFO):
    """
    Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).

    Dropped when `level` is below the current logger's threshold (see set_level).
    """
    Logger.CURRENT.log(*args, level=level)
# Convenience wrappers around log() at a fixed severity level.

def debug(*args):
    log(*args, level=DEBUG)

def info(*args):
    log(*args, level=INFO)

def warn(*args):
    log(*args, level=WARN)

def error(*args):
    log(*args, level=ERROR)
def set_level(level):
    """
    Set logging threshold on current logger.
    Messages logged with a level below this threshold are dropped.
    """
    Logger.CURRENT.set_level(level)
def get_dir():
    """
    Get directory that log files are being written to.
    will be None if there is no output directory (i.e., if you didn't call configure())
    """
    return Logger.CURRENT.get_dir()
record_tabular = logkv
259
dump_tabular = dumpkvs
260
261
class ProfileKV:
    """
    Context manager that accumulates the wall-clock time spent inside its
    body into the diagnostic key "wait_<n>" on Logger.CURRENT.

    Usage:
    with logger.ProfileKV("interesting_scope"):
        code
    """

    def __init__(self, n):
        self.n = "wait_" + n

    def __enter__(self):
        self.t1 = time.time()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Add (not overwrite), so repeated entries into the same scope sum up.
        elapsed = time.time() - self.t1
        Logger.CURRENT.name2val[self.n] += elapsed
def profile(n):
    """
    Decorator: time every call of the wrapped function under the
    diagnostic key "wait_<n>" (via ProfileKV).

    Usage:
    @profile("my_func")
    def my_func(): code
    """
    from functools import wraps  # local import: keeps module deps unchanged

    def decorator_with_name(func):
        # Bug fix: functools.wraps preserves the wrapped function's
        # __name__/__doc__, which the old wrapper silently discarded.
        @wraps(func)
        def func_wrapper(*args, **kwargs):
            with ProfileKV(n):
                return func(*args, **kwargs)
        return func_wrapper
    return decorator_with_name
# ================================================================
# Backend
# ================================================================

class Logger(object):
    """Accumulates key/value diagnostics and forwards them to a set of output writers."""

    DEFAULT = None  # A logger with no output files. (See right below class definition)
                    # So that you can still log to the terminal without setting up any output files
    CURRENT = None  # Current logger being used by the free functions above

    def __init__(self, dir, output_formats):
        # dir: output directory or None; output_formats: KVWriter/SeqWriter instances.
        self.name2val = defaultdict(float)  # values this iteration
        self.name2cnt = defaultdict(int)    # per-key call counts, used by logkv_mean
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        # None wipes the key; no mean can be formed once a None is logged.
        if val is None:
            self.name2val[key] = None
            return
        # Incremental running mean over all calls this iteration.
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
        self.name2cnt[key] = cnt + 1

    def dumpkvs(self):
        if self.level == DISABLED: return
        for fmt in self.output_formats:
            if isinstance(fmt, KVWriter):
                fmt.writekvs(self.name2val)
        # Start the next iteration with a clean slate.
        self.name2val.clear()
        self.name2cnt.clear()

    def log(self, *args, level=INFO):
        # Only emit when the message's level clears the logger's threshold.
        if self.level <= level:
            self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        self.level = level

    def get_dir(self):
        return self.dir

    def close(self):
        for fmt in self.output_formats:
            fmt.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        # Plain-text messages only go to writers that speak sequences.
        for fmt in self.output_formats:
            if isinstance(fmt, SeqWriter):
                fmt.writeseq(map(str, args))
Logger.DEFAULT = Logger.CURRENT = Logger(dir=None, output_formats=[HumanOutputFormat(sys.stdout)])
349
350
def configure(dir=None, format_strs=None):
    """
    Install a new global logger (Logger.CURRENT).

    dir: output directory; falls back to $OPENAI_LOGDIR, then to a fresh
         timestamped directory under the system temp dir.
    format_strs: list of format names ('stdout', 'log', 'csv', 'json',
         'tensorboard'); defaults come from $OPENAI_LOG_FORMAT /
         $OPENAI_LOG_FORMAT_MPI or the module-level constant lists.
    """
    if dir is None:
        dir = os.getenv('OPENAI_LOGDIR')
    if dir is None:
        dir = osp.join(tempfile.gettempdir(),
            datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"))
    assert isinstance(dir, str)
    os.makedirs(dir, exist_ok=True)

    # Under MPI, suffix each non-zero rank's log files so workers don't
    # clobber each other.  Bug fix: mpi4py is now optional -- without it
    # installed we behave as the single (rank 0) process instead of
    # crashing with ImportError.
    log_suffix = ''
    try:
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()
    except ImportError:
        rank = 0
    if rank > 0:
        log_suffix = "-rank%03i" % rank

    if format_strs is None:
        strs, strs_mpi = os.getenv('OPENAI_LOG_FORMAT'), os.getenv('OPENAI_LOG_FORMAT_MPI')
        format_strs = strs_mpi if rank>0 else strs
        if format_strs is not None:
            format_strs = format_strs.split(',')
        else:
            format_strs = LOG_OUTPUT_FORMATS_MPI if rank>0 else LOG_OUTPUT_FORMATS

    output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
    log('Logging to %s'%dir)
def reset():
    """Close the configured logger (if any) and fall back to the stdout-only default."""
    if Logger.CURRENT is not Logger.DEFAULT:
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
        log('Reset logger')
class scoped_configure(object):
    """
    Context manager: configure() a fresh logger on entry and restore the
    previously active Logger.CURRENT on exit.
    """
    def __init__(self, dir=None, format_strs=None):
        self.dir = dir
        self.format_strs = format_strs
        self.prevlogger = None  # captured on __enter__
    def __enter__(self):
        self.prevlogger = Logger.CURRENT
        configure(dir=self.dir, format_strs=self.format_strs)
    def __exit__(self, *args):
        # Close the temporary logger before swapping the old one back in.
        Logger.CURRENT.close()
        Logger.CURRENT = self.prevlogger
# ================================================================

def _demo():
    """Smoke test: exercise the logging API end to end (run via __main__)."""
    info("hi")
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    dir = "/tmp/testlogging"
    if os.path.exists(dir):
        shutil.rmtree(dir)
    configure(dir=dir)
    logkv("a", 3)
    logkv("b", 2.5)
    dumpkvs()
    logkv("b", -2.5)
    logkv("a", 5.5)
    dumpkvs()
    info("^^^ should see a = 5.5")
    # Running-mean logging: two calls within one iteration are averaged.
    logkv_mean("b", -22.5)
    logkv_mean("b", -44.4)
    logkv("a", 5.5)
    dumpkvs()
    info("^^^ should see b = 33.3")

    logkv("b", -2.5)
    dumpkvs()

    # Long values are truncated with '...' by HumanOutputFormat.
    logkv("a", "longasslongasslongasslongasslongasslongassvalue")
    dumpkvs()
# ================================================================
# Readers
# ================================================================

def read_json(fname):
    """Read a file of newline-delimited JSON records into a pandas DataFrame."""
    import pandas
    with open(fname, 'rt') as fh:
        records = [json.loads(line) for line in fh]
    return pandas.DataFrame(records)
def read_csv(fname):
    """Read a progress CSV file (as written by CSVOutputFormat) into a pandas DataFrame."""
    import pandas
    return pandas.read_csv(fname, index_col=None, comment='#')
def read_tb(path):
    """
    path : a tensorboard file OR a directory, where we will find all TB files
           of the form events.*

    Returns a pandas DataFrame with one column per tag (sorted); row i holds
    the simple_value recorded at step i+1, NaN where a tag has no value there.
    """
    import pandas
    import numpy as np
    from glob import glob
    from collections import defaultdict
    import tensorflow as tf
    if osp.isdir(path):
        fnames = glob(osp.join(path, "events.*"))
    elif osp.basename(path).startswith("events."):
        fnames = [path]
    else:
        raise NotImplementedError("Expected tensorboard file or directory containing them. Got %s"%path)
    # tag -> list of (step, value) pairs, gathered across all event files
    tag2pairs = defaultdict(list)
    maxstep = 0
    for fname in fnames:
        for summary in tf.train.summary_iterator(fname):
            # step 0 events (e.g. file headers) carry no diagnostics
            if summary.step > 0:
                for v in summary.summary.value:
                    pair = (summary.step, v.simple_value)
                    tag2pairs[v.tag].append(pair)
                maxstep = max(summary.step, maxstep)
    # Dense (step x tag) matrix, NaN-filled where no value was recorded.
    data = np.empty((maxstep, len(tag2pairs)))
    data[:] = np.nan
    tags = sorted(tag2pairs.keys())
    for (colidx,tag) in enumerate(tags):
        pairs = tag2pairs[tag]
        for (step, value) in pairs:
            data[step-1, colidx] = value
    return pandas.DataFrame(data, columns=tags)
# Run the smoke-test demo when the module is executed as a script.
if __name__ == "__main__":
    _demo()