# local_test/baselines/logger.py
import datetime
import functools
import json
import os
import os.path as osp
import shutil
import sys
import tempfile
import time
from collections import defaultdict
10
11
# Numeric log levels, modeled on the stdlib `logging` module's scale.
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

DISABLED = 50  # level at which dumpkvs()/log() emit nothing
17
18
class KVWriter(object):
    """Interface for output backends that accept a dict of key/value pairs."""

    def writekvs(self, kvs):
        """Write the key/value dict *kvs*; subclasses must override."""
        raise NotImplementedError
21
22
class SeqWriter(object):
    """Interface for output backends that accept a sequence of strings."""

    def writeseq(self, seq):
        """Write the sequence of strings *seq*; subclasses must override."""
        raise NotImplementedError
25
26
class HumanOutputFormat(KVWriter, SeqWriter):
    """Human-readable writer: key/value tables via writekvs(), space-separated
    text via writeseq(). Wraps either a path (stream owned and closed by this
    object) or an already-open stream (left open on close())."""

    def __init__(self, filename_or_file):
        """
        filename_or_file: a str path (opened for writing, owned by us) or a
        writable file-like object (e.g. sys.stdout; not owned by us).
        """
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, 'wt')
            self.own_file = True
        else:
            # FIX: this object only ever writes to the stream, so validate the
            # 'write' attribute. The original asserted hasattr(..., 'read'),
            # which rejects perfectly usable write-only streams.
            assert hasattr(filename_or_file, 'write'), 'expected file or str, got %s'%filename_or_file
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        """Render *kvs* as an ASCII table, one '| key | value |' row per pair."""
        # Create strings for printing (floats get compact %g formatting)
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if isinstance(val, float):
                valstr = '%-8.3g' % (val,)
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Find max widths so the table columns line up
        if len(key2str) == 0:
            print('WARNING: tried to write empty key-value dict')
            return
        else:
            keywidth = max(map(len, key2str.keys()))
            valwidth = max(map(len, key2str.values()))

        # Write out the data, framed by dashed rules
        dashes = '-' * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in sorted(key2str.items()):
            lines.append('| %s%s | %s%s |' % (
                key,
                ' ' * (keywidth - len(key)),
                val,
                ' ' * (valwidth - len(val)),
            ))
        lines.append(dashes)
        self.file.write('\n'.join(lines) + '\n')

        # Flush so progress is visible immediately
        self.file.flush()

    def _truncate(self, s):
        # Keep table cells readable: clip anything longer than 23 chars to 20 + '...'
        return s[:20] + '...' if len(s) > 23 else s

    def writeseq(self, seq):
        """Write the strings in *seq* separated by single spaces, then newline."""
        seq = list(seq)
        for (i, elem) in enumerate(seq):
            self.file.write(elem)
            if i < len(seq) - 1: # add space unless this is the last one
                self.file.write(' ')
        self.file.write('\n')
        self.file.flush()

    def close(self):
        # Only close streams we opened ourselves
        if self.own_file:
            self.file.close()
85
86
class JSONOutputFormat(KVWriter):
    """Writes one JSON object per dump (JSON-lines format)."""

    def __init__(self, filename):
        self.file = open(filename, 'wt')

    def writekvs(self, kvs):
        """Serialize *kvs* as a single JSON line.

        Values with a .dtype attribute (numpy scalars/arrays) are converted
        with .tolist(). FIX: the original did float(v.tolist()), which raises
        TypeError for any non-scalar array; it also mutated the caller's dict.
        Here we build a separate row dict and keep .tolist()'s output, which
        is already JSON-serializable (float for scalars, nested lists for
        arrays).
        """
        row = {}
        for k, v in sorted(kvs.items()):
            if hasattr(v, 'dtype'):
                v = v.tolist()
            row[k] = v
        self.file.write(json.dumps(row) + '\n')
        self.file.flush()

    def close(self):
        self.file.close()
100
101
class CSVOutputFormat(KVWriter):
    """Appends one CSV row per dump, rewriting the header (and padding old
    rows) whenever previously-unseen keys appear."""

    def __init__(self, filename):
        # 'w+t': we need to re-read the file body when the header changes
        self.file = open(filename, 'w+t')
        self.keys = []
        self.sep = ','

    def writekvs(self, kvs):
        """Write *kvs* as a row; missing keys produce empty cells."""
        # FIX: sort the new keys — the original iterated a raw set, so the
        # column order of newly-added keys was nondeterministic.
        extra_keys = sorted(kvs.keys() - self.keys)
        if extra_keys:
            self.keys.extend(extra_keys)
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            # Rewrite the header with the enlarged key list.
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    # FIX: honor self.sep (the original hard-coded ',' here).
                    self.file.write(self.sep)
                self.file.write(k)
            self.file.write('\n')
            # Re-emit old data rows padded with empty cells for the new keys.
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write('\n')
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(self.sep)
            v = kvs.get(k)
            if v is not None:
                self.file.write(str(v))
        self.file.write('\n')
        self.file.flush()

    def close(self):
        self.file.close()
135
136
137
class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.
    """
    def __init__(self, dir):
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1  # event step counter, advanced after every dump
        path = osp.join(osp.abspath(dir), 'events')
        # Deferred imports: tensorflow is only required when this writer is used.
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat
        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))

    def writekvs(self, kvs):
        """Write every (key, value) in *kvs* as a scalar summary at the current step."""
        def to_summary_value(tag, num):
            return self.tf.Summary.Value(tag=tag, simple_value=float(num))

        values = [to_summary_value(k, v) for (k, v) in kvs.items()]
        event = self.event_pb2.Event(
            wall_time=time.time(),
            summary=self.tf.Summary(value=values),
        )
        event.step = self.step # is there any reason why you'd want to specify the step?
        self.writer.WriteEvent(event)
        self.writer.Flush()
        self.step += 1

    def close(self):
        # Idempotent: clear the writer so a second close() is a no-op.
        if self.writer:
            self.writer.Close()
            self.writer = None
171
172
def make_output_format(format, ev_dir, log_suffix=''):
    """Build the writer named by *format* ('stdout', 'log', 'json', 'csv', or
    'tensorboard'), creating *ev_dir* if needed.

    Raises ValueError for an unrecognized format name.
    """
    os.makedirs(ev_dir, exist_ok=True)
    if format == 'stdout':
        return HumanOutputFormat(sys.stdout)
    if format == 'log':
        return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix))
    if format == 'json':
        return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix))
    if format == 'csv':
        return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix))
    if format == 'tensorboard':
        return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix))
    raise ValueError('Unknown format specified: %s' % (format,))
186
187
# ================================================================
188
# API
189
# ================================================================
190
191
def logkv(key, val):
    """
    Log a value of some diagnostic on the current logger.

    Call once per diagnostic quantity per iteration; if called repeatedly for
    the same key within an iteration, only the last value is kept.
    """
    Logger.CURRENT.logkv(key, val)
198
199
def logkv_mean(key, val):
    """
    Like logkv(), but repeated calls within one iteration are averaged.
    """
    Logger.CURRENT.logkv_mean(key, val)
204
205
def logkvs(d):
    """
    Log every key/value pair in the dict *d* via logkv().
    """
    for key, value in d.items():
        logkv(key, value)
211
212
def dumpkvs():
    """
    Write out all diagnostics logged during the current iteration and clear
    them. Emits nothing when the current logger's level is DISABLED.
    """
    Logger.CURRENT.dumpkvs()
220
221
def getkvs():
    """Return the dict of diagnostics logged so far this iteration."""
    return Logger.CURRENT.name2val
223
224
225
def log(*args, level=INFO):
    """
    Write the args (space-separated, per HumanOutputFormat.writeseq) to the
    console and any configured output files, provided *level* is at or above
    the current logger's threshold.
    """
    Logger.CURRENT.log(*args, level=level)
230
231
def debug(*args):
    """log() at DEBUG level."""
    log(*args, level=DEBUG)
233
234
def info(*args):
    """log() at INFO level."""
    log(*args, level=INFO)
236
237
def warn(*args):
    """log() at WARN level."""
    log(*args, level=WARN)
239
240
def error(*args):
    """log() at ERROR level."""
    log(*args, level=ERROR)
242
243
244
def set_level(level):
    """
    Set the logging threshold on the current logger.
    """
    Logger.CURRENT.set_level(level)
249
250
def get_dir():
    """
    Return the directory that log files are written to, or None when no
    output directory is configured (i.e. configure() was never called).
    """
    return Logger.CURRENT.get_dir()
256
257
# Backwards-compatible aliases for the older tabular-logging API names.
record_tabular = logkv
dump_tabular = dumpkvs
259
260
class ProfileKV:
    """
    Context manager that accumulates wall-clock time spent inside the block
    under the key "wait_<n>" of the current logger.

    Usage:
    with logger.ProfileKV("interesting_scope"):
        code
    """
    def __init__(self, n):
        self.n = "wait_" + n

    def __enter__(self):
        self.t1 = time.time()

    def __exit__(self, exc_type, exc_value, tb):
        # Add this block's elapsed time to the running total for the key.
        Logger.CURRENT.name2val[self.n] += time.time() - self.t1
272
273
def profile(n):
    """
    Decorator that wraps a function in ProfileKV(n), so the cumulative time
    spent in the function is logged under "wait_<n>".

    Usage:
    @profile("my_func")
    def my_func(): code
    """
    def decorator_with_name(func):
        # FIX: preserve the wrapped function's __name__/__doc__ metadata —
        # the original wrapper silently replaced them with 'func_wrapper'.
        @functools.wraps(func)
        def func_wrapper(*args, **kwargs):
            with ProfileKV(n):
                return func(*args, **kwargs)
        return func_wrapper
    return decorator_with_name
285
286
287
# ================================================================
288
# Backend
289
# ================================================================
290
291
class Logger(object):
    # A logger with no output files, so you can still log to the terminal
    # before setting anything up. (Assigned right below the class definition.)
    DEFAULT = None
    # The logger currently used by the module-level convenience functions.
    CURRENT = None

    def __init__(self, dir, output_formats):
        """dir: output directory (may be None); output_formats: list of
        KVWriter / SeqWriter instances to forward output to."""
        self.name2val = defaultdict(float)  # diagnostics for this iteration
        self.name2cnt = defaultdict(int)    # sample counts for logkv_mean
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        """Record *val* for *key*, overwriting any earlier value this iteration."""
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        """Record *val* for *key* as a running mean over this iteration."""
        if val is None:
            self.name2val[key] = None
            return
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        # Incremental mean: new_mean = old_mean*cnt/(cnt+1) + val/(cnt+1)
        self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
        self.name2cnt[key] = cnt + 1

    def dumpkvs(self):
        """Flush this iteration's diagnostics to every KVWriter and clear them."""
        if self.level == DISABLED:
            return
        for writer in self.output_formats:
            if isinstance(writer, KVWriter):
                writer.writekvs(self.name2val)
        self.name2val.clear()
        self.name2cnt.clear()

    def log(self, *args, level=INFO):
        """Forward *args* to the sequence writers if *level* passes the threshold."""
        if self.level <= level:
            self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        self.level = level

    def get_dir(self):
        return self.dir

    def close(self):
        for writer in self.output_formats:
            writer.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        # Stringify each arg and hand the sequence to every SeqWriter.
        for writer in self.output_formats:
            if isinstance(writer, SeqWriter):
                writer.writeseq(map(str, args))
346
347
# Terminal-only default logger, so module-level logging works before configure().
Logger.CURRENT = Logger.DEFAULT = Logger(dir=None, output_formats=[HumanOutputFormat(sys.stdout)])
348
349
def configure(dir=None, format_strs=None):
    """
    Point the global logger at a directory and build its output writers.

    dir: output directory; defaults to $OPENAI_LOGDIR, then a fresh
         timestamped directory under the system temp dir.
    format_strs: list of format names (see make_output_format); defaults to
         $OPENAI_LOG_FORMAT ('stdout,log,csv') on MPI rank 0 and
         $OPENAI_LOG_FORMAT_MPI ('log') on other ranks.
    """
    if dir is None:
        dir = os.getenv('OPENAI_LOGDIR')
    if dir is None:
        dir = osp.join(tempfile.gettempdir(),
            datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"))
    assert isinstance(dir, str)
    os.makedirs(dir, exist_ok=True)

    # ROBUSTNESS FIX: don't hard-require mpi4py. The original did an
    # unconditional `from mpi4py import MPI`, crashing any install without
    # MPI; a missing package now means single-process, i.e. rank 0, which
    # preserves the original rank-0 behavior exactly.
    log_suffix = ''
    try:
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()
    except ImportError:
        rank = 0
    if rank > 0:
        log_suffix = "-rank%03i" % rank

    if format_strs is None:
        if rank == 0:
            format_strs = os.getenv('OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',')
        else:
            format_strs = os.getenv('OPENAI_LOG_FORMAT_MPI', 'log').split(',')
    format_strs = filter(None, format_strs)
    output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
    log('Logging to %s'%dir)
374
375
def reset():
    """Close the current logger and restore the default (terminal-only) one."""
    if Logger.CURRENT is not Logger.DEFAULT:
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
        log('Reset logger')
380
381
class scoped_configure(object):
    """Context manager: configure() on entry, restore the previous logger
    (closing the temporary one) on exit."""

    def __init__(self, dir=None, format_strs=None):
        self.dir = dir
        self.format_strs = format_strs
        self.prevlogger = None

    def __enter__(self):
        # Remember whichever logger was active so __exit__ can restore it.
        self.prevlogger = Logger.CURRENT
        configure(dir=self.dir, format_strs=self.format_strs)

    def __exit__(self, *args):
        Logger.CURRENT.close()
        Logger.CURRENT = self.prevlogger
392
393
# ================================================================
394
395
def _demo():
    """Smoke test exercising levels, configure(), logkv/logkv_mean/dumpkvs."""
    info("hi")
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    logdir = "/tmp/testlogging"
    if os.path.exists(logdir):
        shutil.rmtree(logdir)
    configure(dir=logdir)
    logkv("a", 3)
    logkv("b", 2.5)
    dumpkvs()
    logkv("b", -2.5)
    logkv("a", 5.5)
    dumpkvs()
    info("^^^ should see a = 5.5")
    logkv_mean("b", -22.5)
    logkv_mean("b", -44.4)
    logkv("a", 5.5)
    dumpkvs()
    info("^^^ should see b = 33.3")

    logkv("b", -2.5)
    dumpkvs()

    logkv("a", "longasslongasslongasslongasslongasslongassvalue")
    dumpkvs()
422
423
424
# ================================================================
425
# Readers
426
# ================================================================
427
428
def read_json(fname):
    """Read a JSON-lines file (one JSON object per line, as written by
    JSONOutputFormat) into a pandas DataFrame."""
    import pandas
    with open(fname, 'rt') as fh:
        records = [json.loads(line) for line in fh]
    return pandas.DataFrame(records)
435
436
def read_csv(fname):
    """Read a progress CSV (as written by CSVOutputFormat) into a pandas
    DataFrame; lines starting with '#' are treated as comments."""
    import pandas
    return pandas.read_csv(fname, index_col=None, comment='#')
439
440
def read_tb(path):
    """
    path : a tensorboard file OR a directory, where we will find all TB files
           of the form events.*

    Returns a pandas DataFrame: one column per tag, one row per step
    (missing entries are NaN).
    """
    import pandas
    import numpy as np
    from glob import glob
    from collections import defaultdict
    import tensorflow as tf
    if osp.isdir(path):
        fnames = glob(osp.join(path, "events.*"))
    elif osp.basename(path).startswith("events."):
        fnames = [path]
    else:
        raise NotImplementedError("Expected tensorboard file or directory containing them. Got %s"%path)
    # Collect (step, value) pairs per tag across all event files.
    tag2pairs = defaultdict(list)
    maxstep = 0
    for fname in fnames:
        for summary in tf.train.summary_iterator(fname):
            if summary.step > 0:
                for v in summary.summary.value:
                    tag2pairs[v.tag].append((summary.step, v.simple_value))
                maxstep = max(summary.step, maxstep)
    # Dense (step x tag) matrix, NaN where a tag has no value at a step.
    data = np.full((maxstep, len(tag2pairs)), np.nan)
    tags = sorted(tag2pairs.keys())
    for colidx, tag in enumerate(tags):
        for step, value in tag2pairs[tag]:
            data[step - 1, colidx] = value
    return pandas.DataFrame(data, columns=tags)
473
474
if __name__ == "__main__":
    # Run the smoke-test demo when executed as a script.
    _demo()