|
a |
|
b/model_factory.py |
|
|
1 |
import config |
|
|
2 |
|
|
|
3 |
from data_set import DataLoader |
|
|
4 |
from model import Convolution3DNetwork |
|
|
5 |
|
|
|
6 |
import patient_loader as pl |
|
|
7 |
import model_configuration as mc |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
class ModelFactory(object):
    """Select the image loader and network configuration for an experiment.

    The experiment is identified by one of the model constants referenced
    from ``config`` (``BASELINE``, ``BASELINE_WITH_SEGMENTATION``,
    ``NO_REGULARIZATION``, ...). Any identifier not matched explicitly falls
    back to ``pl.SegmentedLungsScansLoader`` + ``mc.DefaultConfig`` with
    augmentation enabled.
    """

    def __init__(self, selected_model=None):
        """Resolve the experiment and prepare its components.

        Args:
            selected_model: experiment identifier to build; when ``None``
                (the default), ``config.SELECTED_MODEL`` is used instead.
        """
        # Test against None explicitly: ``selected_model or
        # config.SELECTED_MODEL`` would silently replace any falsy but valid
        # identifier (e.g. a model constant equal to 0) with the default.
        if selected_model is None:
            selected_model = config.SELECTED_MODEL
        self._selected_model = selected_model
        # Only the augmentation experiments in _init_model switch this on.
        self._with_augmentation = False
        self._init_model()

    def _init_model(self):
        """Set ``_image_loader``, ``_network_config`` and ``_with_augmentation``.

        Branches are checked in order with ``==``; the first match wins.
        """
        model = self._selected_model
        if model == config.BASELINE:
            self._image_loader = pl.MeanScansLoader()
            self._network_config = mc.BaselineConfig()
        elif model == config.BASELINE_WITH_SEGMENTATION:
            self._image_loader = pl.SegmentedGaussianLungsLoader()
            self._network_config = mc.BaselineConfig()
        elif model == config.NO_REGULARIZATION:
            self._image_loader = pl.SegmentedGaussianLungsLoader()
            self._network_config = mc.NoRegularizationConfig()
        elif model == config.NO_REGULARIZATION_WATERSHED:
            # The segmentation algorithm has already been selected and
            # changed during config setup, so the same loader class as for
            # NO_REGULARIZATION is used here.
            self._image_loader = pl.SegmentedGaussianLungsLoader()
            self._network_config = mc.NoRegularizationConfig()
        elif model == config.DROPOUT_L2NORM_REGULARIZARION:
            # NOTE: "REGULARIZARION" is spelled as referenced in ``config``;
            # renaming it here alone would break the lookup.
            self._image_loader = pl.SegmentedGaussianLungsLoader()
            self._network_config = mc.DropoutsWithL2RegularizationConfig()
        elif model == config.REGULARIZATION_MORE_SLICES:
            self._image_loader = pl.SegmentedLungsScansLoader()
            self._network_config = mc.DefaultConfig()
        elif model == config.WITH_DATA_AUGMENTATION:
            self._image_loader = pl.SegmentedLungsScansLoader()
            self._with_augmentation = True
            self._network_config = mc.DefaultConfig()
        else:
            # Default case: unrecognised identifiers get the same setup as
            # WITH_DATA_AUGMENTATION.
            self._image_loader = pl.SegmentedLungsScansLoader()
            self._with_augmentation = True
            self._network_config = mc.DefaultConfig()

    def get_network_model(self):
        """Return a new ``Convolution3DNetwork`` built from the selected config.

        A fresh network object is constructed on every call.
        """
        return Convolution3DNetwork(config=self._network_config)

    def get_data_loader(self):
        """Return a new ``DataLoader`` wrapping the selected image loader.

        ``add_transformed_positives`` is enabled only for the augmentation
        experiments (including the default fallback).
        """
        return DataLoader(images_loader=self._image_loader,
                          add_transformed_positives=self._with_augmentation)