b/local_test/baselines/logger.py

import os
import sys
import shutil
import os.path as osp
import json
import time
import datetime
import tempfile
from collections import defaultdict

DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

DISABLED = 50

class KVWriter(object):
    def writekvs(self, kvs):
        raise NotImplementedError

class SeqWriter(object):
    def writeseq(self, seq):
        raise NotImplementedError

class HumanOutputFormat(KVWriter, SeqWriter):
    def __init__(self, filename_or_file):
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, 'wt')
            self.own_file = True
        else:
            assert hasattr(filename_or_file, 'read'), 'expected file or str, got %s'%filename_or_file
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        # Create strings for printing
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if isinstance(val, float):
                valstr = '%-8.3g' % (val,)
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Find max widths
        if len(key2str) == 0:
            print('WARNING: tried to write empty key-value dict')
            return
        else:
            keywidth = max(map(len, key2str.keys()))
            valwidth = max(map(len, key2str.values()))

        # Write out the data
        dashes = '-' * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in sorted(key2str.items()):
            lines.append('| %s%s | %s%s |' % (
                key,
                ' ' * (keywidth - len(key)),
                val,
                ' ' * (valwidth - len(val)),
            ))
        lines.append(dashes)
        self.file.write('\n'.join(lines) + '\n')

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        return s[:20] + '...' if len(s) > 23 else s

    def writeseq(self, seq):
        seq = list(seq)
        for (i, elem) in enumerate(seq):
            self.file.write(elem)
            if i < len(seq) - 1: # add space unless this is the last one
                self.file.write(' ')
        self.file.write('\n')
        self.file.flush()

    def close(self):
        if self.own_file:
            self.file.close()

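# Example (illustrative, not part of the original module): writing a small dict through
# HumanOutputFormat(sys.stdout) prints a boxed table roughly like
#
#   ---------------------
#   | epslen | 103      |
#   | reward | 12.3     |
#   ---------------------
#
# Floats are rendered with '%-8.3g', and keys/values longer than 23 characters are truncated.
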
class JSONOutputFormat(KVWriter):
    def __init__(self, filename):
        self.file = open(filename, 'wt')

    def writekvs(self, kvs):
        for k, v in sorted(kvs.items()):
            if hasattr(v, 'dtype'):
                v = v.tolist()
                kvs[k] = float(v)
        self.file.write(json.dumps(kvs) + '\n')
        self.file.flush()

    def close(self):
        self.file.close()

class CSVOutputFormat(KVWriter):
    def __init__(self, filename):
        self.file = open(filename, 'w+t')
        self.keys = []
        self.sep = ','

    def writekvs(self, kvs):
        # Handle any keys we haven't seen before: rewrite the header line and
        # pad the rows that were already written with empty columns
        extra_keys = kvs.keys() - self.keys
        if extra_keys:
            self.keys.extend(extra_keys)
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(',')
                self.file.write(k)
            self.file.write('\n')
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write('\n')
        # Write the current row
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(',')
            v = kvs.get(k)
            if v is not None:
                self.file.write(str(v))
        self.file.write('\n')
        self.file.flush()

    def close(self):
        self.file.close()


class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.
    """
    def __init__(self, dir):
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1
        prefix = 'events'
        path = osp.join(osp.abspath(dir), prefix)
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat
        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))

    def writekvs(self, kvs):
        def summary_val(k, v):
            kwargs = {'tag': k, 'simple_value': float(v)}
            return self.tf.Summary.Value(**kwargs)
        summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
        event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
        event.step = self.step # is there any reason why you'd want to specify the step?
        self.writer.WriteEvent(event)
        self.writer.Flush()
        self.step += 1

    def close(self):
        if self.writer:
            self.writer.Close()
            self.writer = None

def make_output_format(format, ev_dir, log_suffix=''):
    os.makedirs(ev_dir, exist_ok=True)
    if format == 'stdout':
        return HumanOutputFormat(sys.stdout)
    elif format == 'log':
        return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix))
    elif format == 'json':
        return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix))
    elif format == 'csv':
        return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix))
    elif format == 'tensorboard':
        return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix))
    else:
        raise ValueError('Unknown format specified: %s' % (format,))
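
# Example (illustrative sketch; the directory below is hypothetical): writers can also be
# built by hand instead of going through configure(). Accepted format names are 'stdout',
# 'log', 'json', 'csv', and 'tensorboard'.
#
#   writer = make_output_format('csv', '/tmp/my_run')   # creates /tmp/my_run/progress.csv
#   writer.writekvs({'loss': 0.25, 'step': 100})
#   writer.close()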

# ================================================================
# API
# ================================================================

def logkv(key, val):
    """
    Log a value of some diagnostic
    Call this once for each diagnostic quantity, each iteration
    If called many times, last value will be used.
    """
    Logger.CURRENT.logkv(key, val)

def logkv_mean(key, val):
    """
    The same as logkv(), but if called many times, values averaged.
    """
    Logger.CURRENT.logkv_mean(key, val)

def logkvs(d):
    """
    Log a dictionary of key-value pairs
    """
    for (k, v) in d.items():
        logkv(k, v)

def dumpkvs():
    """
    Write all of the diagnostics from the current iteration to every configured
    output format, then clear them. Does nothing if the logger level is DISABLED.
    """
    Logger.CURRENT.dumpkvs()

def getkvs():
    return Logger.CURRENT.name2val


def log(*args, level=INFO):
    """
    Write the sequence of args, separated by spaces, to the console and output files
    (if an output file has been configured).
    """
    Logger.CURRENT.log(*args, level=level)

def debug(*args):
    log(*args, level=DEBUG)

def info(*args):
    log(*args, level=INFO)

def warn(*args):
    log(*args, level=WARN)

def error(*args):
    log(*args, level=ERROR)


def set_level(level):
    """
    Set logging threshold on current logger.
    """
    Logger.CURRENT.set_level(level)

def get_dir():
    """
    Get directory that log files are being written to.
    Will be None if there is no output directory (i.e., if you didn't call configure()).
    """
    return Logger.CURRENT.get_dir()

record_tabular = logkv
dump_tabular = dumpkvs

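# Example (illustrative sketch of the free-function API above; names like num_updates and
# reward are placeholders): a typical training loop logs one row per iteration.
#
#   configure(dir='/tmp/my_experiment')          # hypothetical directory
#   for update in range(num_updates):
#       logkv('update', update)
#       logkv_mean('episode_reward', reward)     # averaged if logged several times per iteration
#       dumpkvs()                                # writes one row to every configured output format
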
class ProfileKV:
    """
    Usage:
    with logger.ProfileKV("interesting_scope"):
        code
    """
    def __init__(self, n):
        self.n = "wait_" + n
    def __enter__(self):
        self.t1 = time.time()
    def __exit__(self, type, value, traceback):
        Logger.CURRENT.name2val[self.n] += time.time() - self.t1

def profile(n):
    """
    Usage:
    @profile("my_func")
    def my_func(): code
    """
    def decorator_with_name(func):
        def func_wrapper(*args, **kwargs):
            with ProfileKV(n):
                return func(*args, **kwargs)
        return func_wrapper
    return decorator_with_name
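
# Note (illustrative, not part of the original module): both helpers accumulate elapsed
# seconds into Logger.CURRENT.name2val under the key "wait_<name>", so the timing shows up
# as an extra column the next time dumpkvs() is called. For example:
#
#   @profile("rollout")
#   def collect_rollout(env):            # hypothetical function
#       ...
#   # after calling collect_rollout(), dumpkvs() reports a 'wait_rollout' value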


# ================================================================
# Backend
# ================================================================

class Logger(object):
    DEFAULT = None  # A logger with no output files. (See right below class definition)
                    # So that you can still log to the terminal without setting up any output files
    CURRENT = None  # Current logger being used by the free functions above

    def __init__(self, dir, output_formats):
        self.name2val = defaultdict(float)  # values this iteration
        self.name2cnt = defaultdict(int)
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        if val is None:
            self.name2val[key] = None
            return
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
        self.name2cnt[key] = cnt + 1

    def dumpkvs(self):
        if self.level == DISABLED: return
        for fmt in self.output_formats:
            if isinstance(fmt, KVWriter):
                fmt.writekvs(self.name2val)
        self.name2val.clear()
        self.name2cnt.clear()

    def log(self, *args, level=INFO):
        if self.level <= level:
            self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        self.level = level

    def get_dir(self):
        return self.dir

    def close(self):
        for fmt in self.output_formats:
            fmt.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        for fmt in self.output_formats:
            if isinstance(fmt, SeqWriter):
                fmt.writeseq(map(str, args))

Logger.DEFAULT = Logger.CURRENT = Logger(dir=None, output_formats=[HumanOutputFormat(sys.stdout)])

def configure(dir=None, format_strs=None):
    if dir is None:
        dir = os.getenv('OPENAI_LOGDIR')
    if dir is None:
        dir = osp.join(tempfile.gettempdir(),
            datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"))
    assert isinstance(dir, str)
    os.makedirs(dir, exist_ok=True)

    log_suffix = ''
    from mpi4py import MPI
    rank = MPI.COMM_WORLD.Get_rank()
    if rank > 0:
        log_suffix = "-rank%03i" % rank

    if format_strs is None:
        if rank == 0:
            format_strs = os.getenv('OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',')
        else:
            format_strs = os.getenv('OPENAI_LOG_FORMAT_MPI', 'log').split(',')
    format_strs = filter(None, format_strs)
    output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
    log('Logging to %s'%dir)

def reset():
    if Logger.CURRENT is not Logger.DEFAULT:
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
        log('Reset logger')

class scoped_configure(object):
    def __init__(self, dir=None, format_strs=None):
        self.dir = dir
        self.format_strs = format_strs
        self.prevlogger = None
    def __enter__(self):
        self.prevlogger = Logger.CURRENT
        configure(dir=self.dir, format_strs=self.format_strs)
    def __exit__(self, *args):
        Logger.CURRENT.close()
        Logger.CURRENT = self.prevlogger

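# Example (illustrative; the directory is hypothetical): scoped_configure temporarily swaps
# in a freshly configured logger and restores the previous Logger.CURRENT on exit.
#
#   with scoped_configure(dir='/tmp/eval_run', format_strs=['csv']):
#       logkv('eval_reward', 1.0)
#       dumpkvs()
#   # outside the block, logging goes back to whatever logger was active before
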
# ================================================================

def _demo():
    info("hi")
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    dir = "/tmp/testlogging"
    if os.path.exists(dir):
        shutil.rmtree(dir)
    configure(dir=dir)
    logkv("a", 3)
    logkv("b", 2.5)
    dumpkvs()
    logkv("b", -2.5)
    logkv("a", 5.5)
    dumpkvs()
    info("^^^ should see a = 5.5")
    logkv_mean("b", -22.5)
    logkv_mean("b", -44.4)
    logkv("a", 5.5)
    dumpkvs()
    info("^^^ should see b = -33.45")

    logkv("b", -2.5)
    dumpkvs()

    logkv("a", "longasslongasslongasslongasslongasslongassvalue")
    dumpkvs()


# ================================================================
# Readers
# ================================================================

def read_json(fname):
    import pandas
    ds = []
    with open(fname, 'rt') as fh:
        for line in fh:
            ds.append(json.loads(line))
    return pandas.DataFrame(ds)

def read_csv(fname):
    import pandas
    return pandas.read_csv(fname, index_col=None, comment='#')

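# Example (illustrative; the paths are hypothetical): loading the progress files written by
# the 'json' and 'csv' output formats back into pandas for analysis.
#
#   df = read_csv('/tmp/my_experiment/progress.csv')
#   df = read_json('/tmp/my_experiment/progress.json')
#   df['episode_reward'].plot()
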
def read_tb(path):
    """
    path : a tensorboard file OR a directory, where we will find all TB files
           of the form events.*
    """
    import pandas
    import numpy as np
    from glob import glob
    from collections import defaultdict
    import tensorflow as tf
    if osp.isdir(path):
        fnames = glob(osp.join(path, "events.*"))
    elif osp.basename(path).startswith("events."):
        fnames = [path]
    else:
        raise NotImplementedError("Expected tensorboard file or directory containing them. Got %s"%path)
    tag2pairs = defaultdict(list)
    maxstep = 0
    for fname in fnames:
        for summary in tf.train.summary_iterator(fname):
            if summary.step > 0:
                for v in summary.summary.value:
                    pair = (summary.step, v.simple_value)
                    tag2pairs[v.tag].append(pair)
                maxstep = max(summary.step, maxstep)
    data = np.empty((maxstep, len(tag2pairs)))
    data[:] = np.nan
    tags = sorted(tag2pairs.keys())
    for (colidx, tag) in enumerate(tags):
        pairs = tag2pairs[tag]
        for (step, value) in pairs:
            data[step-1, colidx] = value
    return pandas.DataFrame(data, columns=tags)

if __name__ == "__main__":
    _demo()