Diff of /src/experiment.py [000000] .. [95f789]

Switch to unified view

a b/src/experiment.py
1
from collections import OrderedDict
2
import torch
3
import torch.nn as nn
4
from torch.utils.data import ConcatDataset
5
import random
6
from catalyst.dl.experiment import ConfigExperiment
7
from dataset import *
8
from augmentation import train_aug, valid_aug
9
10
11
class Experiment(ConfigExperiment):
12
    def _postprocess_model_for_stage(self, stage: str, model: nn.Module):
13
14
        import warnings
15
        warnings.filterwarnings("ignore")
16
17
        random.seed(2411)
18
        np.random.seed(2411)
19
        torch.manual_seed(2411)
20
21
        model_ = model
22
        if isinstance(model, torch.nn.DataParallel):
23
            model_ = model_.module
24
25
        if stage == "warmup":
26
            if hasattr(model_, 'freeze'):
27
                model_.freeze(model_)
28
                print("Freeze backbone model using freeze method !!!")
29
            else:
30
                for param in model_.parameters():
31
                    param.requires_grad = False
32
33
                for param in model_.get_classifier().parameters():
34
                    param.requires_grad = True
35
                print("Freeze backbone model !!!")
36
37
        else:
38
            if hasattr(model_, 'unfreeze'):
39
                model_.unfreeze(model_)
40
                print("Unfreeze backbone model using unfreeze method !!!")
41
            else:
42
                for param in model_.parameters():
43
                    param.requires_grad = True
44
45
                print("Unfreeze backbone model !!!")
46
        #
47
        # import apex
48
        # model_ = apex.parallel.convert_syncbn_model(model_)
49
50
        return model_
51
52
    def get_datasets(self, stage: str, **kwargs):
53
        datasets = OrderedDict()
54
55
        """
56
        image_key: 'id'
57
        label_key: 'attribute_ids'
58
        """
59
60
        image_size = kwargs.get("image_size", [224, 224])
61
        train_csv = kwargs.get('train_csv', None)
62
        valid_csv = kwargs.get('valid_csv', None)
63
        with_any = kwargs.get('with_any', True)
64
        dataset_type = kwargs.get('dataset_type', 'RSNADataset')
65
        image_type = kwargs.get('image_type', 'jpg')
66
        normalization = kwargs.get('normalization', True)
67
        root = kwargs.get('root', None)
68
69
        print(f"Image Size: {image_size}")
70
71
        if train_csv:
72
            transform = train_aug(image_size)
73
            if dataset_type == 'RSNADataset':
74
                train_set = RSNADataset(
75
                    csv_file=train_csv,
76
                    root=root,
77
                    with_any=with_any,
78
                    transform=transform,
79
                    mode='train',
80
                    image_type=image_type
81
                )
82
            elif dataset_type == 'RSNAMultiWindowsDataset':
83
                train_set = RSNAMultiWindowsDataset(
84
                    csv_file=train_csv,
85
                    root=root,
86
                    with_any=with_any,
87
                    transform=transform
88
                )
89
            elif dataset_type == 'RSNADicomDataset':
90
                train_set = RSNADicomDataset(
91
                    csv_file=train_csv,
92
                    root=root,
93
                    with_any=with_any,
94
                    transform=transform
95
                )
96
            elif dataset_type == "RSNARandomWindowDataset":
97
                train_set = RSNARandomWindowDataset(
98
                    csv_file=train_csv,
99
                    root=root,
100
                    with_any=with_any,
101
                    transform=transform
102
                )
103
            else:
104
                raise("No Dataset: {}".format(dataset_type))
105
            datasets["train"] = train_set
106
107
        if valid_csv:
108
            transform = valid_aug(image_size)
109
            if dataset_type == 'RSNADataset':
110
                valid_set = RSNADataset(
111
                    csv_file=valid_csv,
112
                    root=root,
113
                    with_any=with_any,
114
                    transform=transform,
115
                    mode='valid',
116
                    image_type=image_type
117
                )
118
            elif dataset_type == 'RSNAMultiWindowsDataset':
119
                valid_set = RSNAMultiWindowsDataset(
120
                    csv_file=valid_csv,
121
                    root=root,
122
                    with_any=with_any,
123
                    transform=transform
124
                )
125
            elif dataset_type == 'RSNADicomDataset':
126
                valid_set = RSNADicomDataset(
127
                    csv_file=valid_csv,
128
                    root=root,
129
                    with_any=with_any,
130
                    transform=transform,
131
                    mode='valid'
132
                )
133
            elif dataset_type == "RSNARandomWindowDataset":
134
                valid_set = RSNARandomWindowDataset(
135
                    csv_file=valid_csv,
136
                    root=root,
137
                    with_any=with_any,
138
                    transform=transform,
139
                    mode='valid'
140
                )
141
            else:
142
                raise("No Dataset: {}".format(dataset_type))
143
            datasets["valid"] = valid_set
144
145
        return datasets