a b/model_factory.py
1
import config
2
3
from data_set import DataLoader
4
from model import Convolution3DNetwork
5
6
import patient_loader as pl
7
import model_configuration as mc
8
9
10
class ModelFactory(object):
    """Build the network and data loader that belong to one model variant.

    The variant is identified by one of the model constants defined in
    ``config``. It determines which patient image loader and which network
    configuration are used, and whether the data loader is asked to add
    transformed positives.
    """

    def __init__(self, selected_model=None):
        """Select a model variant and prepare its loader and configuration.

        Args:
            selected_model: One of the model constants from ``config``.
                When ``None`` (the default), ``config.SELECTED_MODEL`` is
                used instead.
        """
        # Compare against None explicitly: a plain ``or`` would also discard
        # a legitimate but falsy selection (for example a constant equal
        # to 0) and silently fall back to the configured default.
        if selected_model is None:
            selected_model = config.SELECTED_MODEL
        self._selected_model = selected_model
        self._with_augmentation = False
        self._init_model()

    def _init_model(self):
        """Set the image loader, network config and augmentation flag.

        The branches are checked in order and the first constant equal to
        the selected model wins. A value that matches none of the known
        constants gets the default setup (last branch).
        """
        selected = self._selected_model

        if selected == config.BASELINE:
            self._image_loader = pl.MeanScansLoader()
            self._network_config = mc.BaselineConfig()
        elif selected == config.BASELINE_WITH_SEGMENTATION:
            self._image_loader = pl.SegmentedGaussianLungsLoader()
            self._network_config = mc.BaselineConfig()
        elif selected == config.NO_REGULARIZATION:
            self._image_loader = pl.SegmentedGaussianLungsLoader()
            self._network_config = mc.NoRegularizationConfig()
        elif selected == config.NO_REGULARIZATION_WATERSHED:
            # The segmentation algorithm has already been selected and
            # changed during config setup, so the loader and network config
            # are the same as for NO_REGULARIZATION.
            self._image_loader = pl.SegmentedGaussianLungsLoader()
            self._network_config = mc.NoRegularizationConfig()
        elif selected == config.DROPOUT_L2NORM_REGULARIZARION:
            self._image_loader = pl.SegmentedGaussianLungsLoader()
            self._network_config = mc.DropoutsWithL2RegularizationConfig()
        elif selected == config.REGULARIZATION_MORE_SLICES:
            self._image_loader = pl.SegmentedLungsScansLoader()
            self._network_config = mc.DefaultConfig()
        elif selected == config.WITH_DATA_AUGMENTATION:
            self._image_loader = pl.SegmentedLungsScansLoader()
            self._with_augmentation = True
            self._network_config = mc.DefaultConfig()
        else:
            # Default case: same setup as WITH_DATA_AUGMENTATION.
            self._image_loader = pl.SegmentedLungsScansLoader()
            self._with_augmentation = True
            self._network_config = mc.DefaultConfig()

    def get_network_model(self):
        """Return a new ``Convolution3DNetwork`` using the selected config."""
        return Convolution3DNetwork(config=self._network_config)

    def get_data_loader(self):
        """Return a new ``DataLoader`` wrapping the selected image loader.

        The augmentation flag chosen for the model variant is passed on as
        ``add_transformed_positives``.
        """
        return DataLoader(images_loader=self._image_loader,
                          add_transformed_positives=self._with_augmentation)