[b48499]: / test / test_components / test_loggers.py

Download this file

104 lines (73 with data), 2.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
"""
Tests for the logger components in ``torch_ecg.components.loggers``.
"""
from pathlib import Path
import pytest
import torch
from torch_ecg.components.loggers import BaseLogger, CSVLogger, LoggerManager, TensorBoardXLogger, TxtLogger
_LOG_DIR = Path(__file__).parents[1] / "logs"
def test_base_logger():
    """``BaseLogger`` is abstract, so instantiating it directly must raise ``TypeError``."""
    with pytest.raises(TypeError) as excinfo:
        BaseLogger()
    # Check the message separately; equivalent to passing ``match=`` above.
    excinfo.match("Can't instantiate abstract class")
def test_logger_manager():
    """Build a ``LoggerManager`` with txt, csv and tensorboardX backends and drive it.

    Checks that ``log_dir`` and ``log_suffix`` from the config are reflected on
    the manager and that ``str`` and ``repr`` agree. It then logs a message and a
    metric dict, flushes, and closes the manager.
    """
    config = {
        "log_dir": _LOG_DIR,
        "log_suffix": "test",
        "txt_logger": True,
        "csv_logger": True,
        "tensorboardx_logger": True,
    }
    lm = LoggerManager.from_config(config)
    # ``close()`` lives in ``finally`` so it still runs when an assertion or a
    # logging call below fails, instead of leaving the manager open.
    try:
        assert lm.log_dir == _LOG_DIR
        assert lm.log_suffix == "test"
        assert str(lm) == repr(lm)
        # NOTE: ``lm._add_wandb_logger()`` is not implemented yet (calling it is
        # expected to raise ``NotImplementedError``); add coverage here once it is.
        lm.log_message("test")
        lm.log_metrics({"test": torch.scalar_tensor(1.1)})
        lm.flush()
    finally:
        lm.close()
def _check_file_logger(logger_cls):
    """Run a file-backed logger class through a full create/log/flush/close cycle.

    Shared by the per-backend tests below, which previously duplicated this body.

    Parameters
    ----------
    logger_cls : type
        A concrete logger class exposing ``from_config``, ``log_message``,
        ``log_metrics``, ``flush``, ``close``, ``log_dir``, ``log_file`` and
        ``filename`` (e.g. ``TxtLogger``).
    """
    config = {
        "log_dir": _LOG_DIR,
        "log_suffix": "test",
    }
    logger = logger_cls.from_config(config)
    try:
        assert logger.log_dir == _LOG_DIR
        assert str(logger) == repr(logger)
        logger.log_message("test")
        logger.log_metrics({"test": torch.scalar_tensor(1.1)})
        logger.flush()
    finally:
        # Always close, even if a check above fails, so a failing test does not
        # leave the logger open.
        logger.close()
    # These checks deliberately run after ``close()``: the log file must still be
    # on disk, and ``str``/``repr`` must still agree on a closed logger.
    assert logger.filename == str(_LOG_DIR / logger.log_file)
    assert Path(logger.filename).exists()
    assert str(logger) == repr(logger)


def test_txt_logger():
    """``TxtLogger`` completes a full cycle and leaves its log file on disk."""
    _check_file_logger(TxtLogger)


def test_csv_logger():
    """``CSVLogger`` completes a full cycle and leaves its log file on disk."""
    _check_file_logger(CSVLogger)


def test_tensorboardx_logger():
    """``TensorBoardXLogger`` completes a full cycle and leaves its log file on disk."""
    _check_file_logger(TensorBoardXLogger)