[d6d24a]: / Segmentation / tests / ml_tests / test_unet.py

Download this file

47 lines (35 with data), 1.5 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import tensorflow as tf
class UNetTest(tf.test.TestCase):
def testVanillaUNet(self):
import json
import os
config = 'config.json'
if os.path.isfile(config):
with open(config) as json_data_file:
data = json.load(json_data_file)
if data['train'] == 'false':
return
tf.random.set_seed(83922)
dataset_size = 10
batch_size = 5
input_shape = (256, 256, 1)
num_classes = 3
output_shape = (256, 256, num_classes)
num_filters = [64, 128, 256, 512, 1024]
features = tf.random.normal((dataset_size,) + input_shape)
labels = tf.random.normal((dataset_size,) + output_shape)
features, labels = self.evaluate([features, labels])
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = dataset.repeat().shuffle(dataset_size).batch(batch_size)
from Segmentation.model.unet import UNet # local file import
model = UNet(num_filters, num_classes)
from Segmentation.utils.losses import dice_coef_loss # local file import
model.compile('adam', loss=dice_coef_loss)
history = model.fit(dataset,
steps_per_epoch=dataset_size // batch_size,
epochs=2)
if __name__ == '__main__':
import sys
from os import getcwd
sys.path.insert(0, getcwd())
tf.test.main()