lightning/util.py
"""Utility helpers for logging, checkpointing, and early stopping in PyTorch Lightning experiments."""
import os

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


def get_ckpt_dir(save_path, exp_name):
    """Return the checkpoint directory for the given experiment."""
    return os.path.join(save_path, exp_name, "ckpts")


def get_ckpt_callback(save_path, exp_name):
    """Checkpoint callback that keeps the single best model by validation loss."""
    ckpt_dir = get_ckpt_dir(save_path, exp_name)
    return ModelCheckpoint(dirpath=ckpt_dir,
                           save_top_k=1,
                           verbose=True,
                           monitor='val_loss',
                           mode='min')


def get_early_stop_callback(patience=10):
    """Early-stopping callback that halts training once validation loss stops improving."""
    return EarlyStopping(monitor='val_loss',
                         patience=patience,
                         verbose=True,
                         mode='min')


def get_logger(logger_type, save_path, exp_name, project_name=None):
    """Return a Weights & Biases or TestTube logger for the given experiment."""
    if logger_type == 'wandb':
        if project_name is None:
            raise ValueError("Must supply project name when using wandb logger.")
        return WandbLogger(name=exp_name,
                           project=project_name)
    elif logger_type == 'test_tube':
        exp_dir = os.path.join(save_path, exp_name)
        return TestTubeLogger(save_dir=exp_dir,
                              name='lightning_logs',
                              version="0")
    else:
        raise ValueError(f'{logger_type} is not a supported logger.')
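

# Usage sketch (illustration only): one way these helpers could be wired into a
# pytorch_lightning.Trainer. The values below ("experiments", "baseline",
# "my-project") and the LightningModule are hypothetical placeholders, not part
# of this module; the snippet assumes a Lightning version where ModelCheckpoint
# may be passed via the `callbacks` list.
if __name__ == "__main__":
    import pytorch_lightning as pl

    save_path = "experiments"   # hypothetical output directory
    exp_name = "baseline"       # hypothetical experiment name

    trainer = pl.Trainer(
        logger=get_logger('wandb', save_path, exp_name, project_name="my-project"),
        callbacks=[get_ckpt_callback(save_path, exp_name),
                   get_early_stop_callback(patience=10)],
        max_epochs=100,
    )
    # trainer.fit(model)  # `model` would be your LightningModule instance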