--- a +++ b/tools/print_f1.py @@ -0,0 +1,88 @@ +# Copyright 2018 Google LLC. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +r"""Parse and extract metrics from *.metrics files.""" + +import argparse +import json +import logging +from os import listdir +from os.path import isfile +from os.path import join +import re + + +def parse_cmdline(argv): + """Parse the commandline.""" + parser = argparse.ArgumentParser() + + parser.add_argument( + '--metrics_dir', help='Path to the directory with metrics files.') + + known_args, _ = parser.parse_known_args(argv) + + return known_args + + +def extract_checkpoint_number_from_metrics_filename(filename): + match = re.search(r'ckpt-([\d]*)\.metrics', filename) + if match: + return int(match.group(1)) + + +def read_metrics_file(path): + """Reads metrics f and outputs metrics in a dict.""" + with open(path) as f: + metrics = { + key.replace('/', '_'): float(value) + for key, value in json.loads(f.read()).items() + } + metrics['checkpoint'] = extract_checkpoint_number_from_metrics_filename(path) + metrics['F1_All'] = 2 * metrics['TPs_All'] / ( + 2 * metrics['TPs_All'] + metrics['FNs_All'] + metrics['FPs_All']) + metrics['TPs+FNs_All'] = metrics['TPs_All'] + metrics['FNs_All'] + return metrics + + +def main(argv=None): + """Main entry point.""" + known_args = parse_cmdline(argv) + metrics_dir = known_args.metrics_dir + metrics_files = [ + join(metrics_dir, f) + for f in listdir(metrics_dir) + if isfile(join(metrics_dir, f)) + ] + metrics = [read_metrics_file(f) for f in metrics_files] + for m in metrics: + print('%s\t%s\t%s' % (m['checkpoint'], m['TPs+FNs_All'], m['F1_All'])) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + main()