|
a |
|
b/lightning/util.py |
|
|
1 |
"""Define Logger class for logging information to stdout and disk.""" |
|
|
2 |
import json |
|
|
3 |
import os |
|
|
4 |
from os.path import join |
|
|
5 |
from pytorch_lightning.loggers import WandbLogger |
|
|
6 |
from pytorch_lightning.loggers.test_tube import TestTubeLogger |
|
|
7 |
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping |
|
|
8 |
|
|
|
9 |
def get_ckpt_dir(save_path, exp_name): |
|
|
10 |
return os.path.join(save_path, exp_name, "ckpts") |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
def get_ckpt_callback(save_path, exp_name): |
|
|
14 |
ckpt_dir = os.path.join(save_path, exp_name, "ckpts") |
|
|
15 |
return ModelCheckpoint(dirpath=ckpt_dir, |
|
|
16 |
save_top_k=1, |
|
|
17 |
verbose=True, |
|
|
18 |
monitor='val_loss', |
|
|
19 |
mode='min') |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
def get_early_stop_callback(patience=10): |
|
|
23 |
return EarlyStopping(monitor='val_loss', |
|
|
24 |
patience=patience, |
|
|
25 |
verbose=True, |
|
|
26 |
mode='min') |
|
|
27 |
|
|
|
28 |
|
|
|
29 |
def get_logger(logger_type, save_path, exp_name, project_name=None): |
|
|
30 |
if logger_type == 'wandb': |
|
|
31 |
if project_name is None: |
|
|
32 |
raise ValueError("Must supply project name when using wandb logger.") |
|
|
33 |
return WandbLogger(name=exp_name, |
|
|
34 |
project=project_name) |
|
|
35 |
elif logger_type == 'test_tube': |
|
|
36 |
exp_dir = os.path.join(save_path, exp_name) |
|
|
37 |
return TestTubeLogger(save_dir=exp_dir, |
|
|
38 |
name='lightning_logs', |
|
|
39 |
version="0") |
|
|
40 |
else: |
|
|
41 |
raise ValueError(f'{logger_type} is not a supported logger.') |
|
|
42 |
|