--- a +++ b/src/experiment.py @@ -0,0 +1,145 @@ +from collections import OrderedDict +import torch +import torch.nn as nn +from torch.utils.data import ConcatDataset +import random +from catalyst.dl.experiment import ConfigExperiment +from dataset import * +from augmentation import train_aug, valid_aug + + +class Experiment(ConfigExperiment): + def _postprocess_model_for_stage(self, stage: str, model: nn.Module): + + import warnings + warnings.filterwarnings("ignore") + + random.seed(2411) + np.random.seed(2411) + torch.manual_seed(2411) + + model_ = model + if isinstance(model, torch.nn.DataParallel): + model_ = model_.module + + if stage == "warmup": + if hasattr(model_, 'freeze'): + model_.freeze(model_) + print("Freeze backbone model using freeze method !!!") + else: + for param in model_.parameters(): + param.requires_grad = False + + for param in model_.get_classifier().parameters(): + param.requires_grad = True + print("Freeze backbone model !!!") + + else: + if hasattr(model_, 'unfreeze'): + model_.unfreeze(model_) + print("Unfreeze backbone model using unfreeze method !!!") + else: + for param in model_.parameters(): + param.requires_grad = True + + print("Unfreeze backbone model !!!") + # + # import apex + # model_ = apex.parallel.convert_syncbn_model(model_) + + return model_ + + def get_datasets(self, stage: str, **kwargs): + datasets = OrderedDict() + + """ + image_key: 'id' + label_key: 'attribute_ids' + """ + + image_size = kwargs.get("image_size", [224, 224]) + train_csv = kwargs.get('train_csv', None) + valid_csv = kwargs.get('valid_csv', None) + with_any = kwargs.get('with_any', True) + dataset_type = kwargs.get('dataset_type', 'RSNADataset') + image_type = kwargs.get('image_type', 'jpg') + normalization = kwargs.get('normalization', True) + root = kwargs.get('root', None) + + print(f"Image Size: {image_size}") + + if train_csv: + transform = train_aug(image_size) + if dataset_type == 'RSNADataset': + train_set = RSNADataset( + csv_file=train_csv, + root=root, + with_any=with_any, + transform=transform, + mode='train', + image_type=image_type + ) + elif dataset_type == 'RSNAMultiWindowsDataset': + train_set = RSNAMultiWindowsDataset( + csv_file=train_csv, + root=root, + with_any=with_any, + transform=transform + ) + elif dataset_type == 'RSNADicomDataset': + train_set = RSNADicomDataset( + csv_file=train_csv, + root=root, + with_any=with_any, + transform=transform + ) + elif dataset_type == "RSNARandomWindowDataset": + train_set = RSNARandomWindowDataset( + csv_file=train_csv, + root=root, + with_any=with_any, + transform=transform + ) + else: + raise("No Dataset: {}".format(dataset_type)) + datasets["train"] = train_set + + if valid_csv: + transform = valid_aug(image_size) + if dataset_type == 'RSNADataset': + valid_set = RSNADataset( + csv_file=valid_csv, + root=root, + with_any=with_any, + transform=transform, + mode='valid', + image_type=image_type + ) + elif dataset_type == 'RSNAMultiWindowsDataset': + valid_set = RSNAMultiWindowsDataset( + csv_file=valid_csv, + root=root, + with_any=with_any, + transform=transform + ) + elif dataset_type == 'RSNADicomDataset': + valid_set = RSNADicomDataset( + csv_file=valid_csv, + root=root, + with_any=with_any, + transform=transform, + mode='valid' + ) + elif dataset_type == "RSNARandomWindowDataset": + valid_set = RSNARandomWindowDataset( + csv_file=valid_csv, + root=root, + with_any=with_any, + transform=transform, + mode='valid' + ) + else: + raise("No Dataset: {}".format(dataset_type)) + datasets["valid"] = valid_set + + return datasets