[139527]: / lightning / util.py

Download this file

43 lines (34 with data), 1.5 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""Define Logger class for logging information to stdout and disk."""
import json
import os
from os.path import join
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
def get_ckpt_dir(save_path, exp_name):
return os.path.join(save_path, exp_name, "ckpts")
def get_ckpt_callback(save_path, exp_name):
ckpt_dir = os.path.join(save_path, exp_name, "ckpts")
return ModelCheckpoint(dirpath=ckpt_dir,
save_top_k=1,
verbose=True,
monitor='val_loss',
mode='min')
def get_early_stop_callback(patience=10):
return EarlyStopping(monitor='val_loss',
patience=patience,
verbose=True,
mode='min')
def get_logger(logger_type, save_path, exp_name, project_name=None):
if logger_type == 'wandb':
if project_name is None:
raise ValueError("Must supply project name when using wandb logger.")
return WandbLogger(name=exp_name,
project=project_name)
elif logger_type == 'test_tube':
exp_dir = os.path.join(save_path, exp_name)
return TestTubeLogger(save_dir=exp_dir,
name='lightning_logs',
version="0")
else:
raise ValueError(f'{logger_type} is not a supported logger.')