[b48499]: / torch_ecg / components / nas.py

Download this file

109 lines (94 with data), 3.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
"""
Neural architecture search (NAS).

TODO: replace this module with Ray Tune.
"""
from typing import Sequence, Type

import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel as DP
from torch.utils.data.dataset import Dataset

from ..cfg import CFG
from .trainer import BaseTrainer
# Names exported by ``from <this module> import *``.
__all__ = ["NAS"]
class NAS:
    """Run a sequential architecture search over a list of model configurations.

    For every entry of ``model_configs`` a model is built from ``model_cls``
    and trained by a newly created ``trainer_cls`` instance. All runs share
    the same training and validation datasets.
    """

    __name__ = "NAS"

    def __init__(
        self,
        trainer_cls: "Type[BaseTrainer]",
        model_cls: Type[nn.Module],
        dataset_cls: Type[Dataset],
        train_config: dict,
        model_configs: Sequence[dict],
        lazy: bool = False,
    ) -> None:
        """
        Parameters
        ----------
        trainer_cls: type (subclass of BaseTrainer),
            trainer class (the class itself, not an instance);
            it is instantiated once per model configuration in `search`
        model_cls: type (subclass of nn.Module),
            model class (the class itself, not an instance);
            it is called with the keyword arguments
            `classes`, `n_leads` and `config`
        dataset_cls: type (subclass of Dataset),
            dataset class (the class itself, not an instance);
            it is called as `dataset_cls(train_config, training=..., lazy=False)`
        train_config: dict,
            train configurations, wrapped into a `CFG`;
            `search` reads its `classes` and `n_leads` entries
        model_configs: sequence of dict,
            model configurations, each with a different network architecture
        lazy: bool, default False,
            if False, the training and validation datasets are created
            at initialization;
            if True, they are left unset and have to be provided via
            `_setup_dataset` before `search` is called
        """
        self.trainer_cls = trainer_cls
        self.model_cls = model_cls
        self.dataset_cls = dataset_cls
        self.train_config = CFG(train_config)
        self.model_configs = model_configs
        self.lazy = lazy
        if lazy:
            # Datasets are supplied later through `_setup_dataset`.
            self.ds_train = None
            self.ds_val = None
        else:
            self.ds_train = self.dataset_cls(
                self.train_config, training=True, lazy=False
            )
            self.ds_val = self.dataset_cls(
                self.train_config, training=False, lazy=False
            )

    def search(self) -> None:
        """Train one model per entry of `model_configs`, one after another.

        Each model is moved to the CUDA device when one is available
        (otherwise to the CPU) and is wrapped in `DataParallel` when more
        than one CUDA device is visible. Nothing is returned or collected
        here; any outcome of a run is whatever the trainer itself keeps.

        Raises
        ------
        ValueError,
            if the training dataset or the validation dataset is not set
        """
        if self.ds_train is None or self.ds_val is None:
            raise ValueError("training dataset or validation dataset is not set")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The number of visible devices does not change between candidates,
        # so decide once whether to use DataParallel.
        use_data_parallel = torch.cuda.device_count() > 1

        for model_config in self.model_configs:
            model = self.model_cls(
                classes=self.train_config.classes,
                n_leads=self.train_config.n_leads,
                config=model_config,
            )
            if use_data_parallel:
                model = DP(model)
            model.to(device=device)
            model.train()

            # The trainer is created with lazy=True and then handed the
            # datasets held by this object, so every run reuses the same
            # dataset instances.
            trainer = self.trainer_cls(
                model=model,
                dataset_cls=self.dataset_cls,
                train_config=self.train_config,
                model_config=model_config,
                device=device,
                lazy=True,
            )
            trainer._setup_dataloaders(self.ds_train, self.ds_val)
            trainer.train()

            # Drop the references to the finished candidate before building
            # the next one, then release cached CUDA memory.
            del model
            del trainer
            torch.cuda.empty_cache()

    def _setup_dataset(self, ds_train: Dataset, ds_val: Dataset) -> None:
        """Set (or replace) the datasets used by `search`.

        Parameters
        ----------
        ds_train: Dataset,
            training dataset
        ds_val: Dataset,
            validation dataset
        """
        self.ds_train = ds_train
        self.ds_val = ds_val