[2b78a8]: /src/BSNet/utils.py

""" Utility functions for segmentation models """
from functools import wraps
import numpy as np
from typing import Tuple, Any, Sequence
import tensorflow as tf
def get_layer_number(model, layer_name):
"""
Help find layer in Keras model by name
Args:
model: Keras `Model`
layer_name: str, name of layer
Returns:
index of layer
Raises:
ValueError: if model does not contains layer with such name
"""
for i, l in enumerate(model.layers):
if l.name == layer_name:
return i
raise ValueError('No layer with name {} in model {}.'.format(layer_name, model.name))
def extract_outputs(model, layers, include_top=False):
"""
Help extract intermediate layer outputs from model
Args:
model: Keras `Model`
layers: list of integers/str, list of layers indexes or names to extract output
include_top: bool, include final model layer output
Returns:
list of tensors (outputs)
"""
layers_indexes = ([get_layer_number(model, l) if isinstance(l, str) else l
for l in layers])
outputs = [model.layers[i].output for i in layers_indexes]
if include_top:
outputs.insert(0, model.output)
return outputs
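
# Usage sketch (illustrative, not part of the original file): collecting skip
# connection tensors from a backbone. The VGG16 backbone and layer names below
# are assumptions chosen only for demonstration.
#
#     backbone = tf.keras.applications.VGG16(include_top=False, weights=None,
#                                             input_shape=(224, 224, 3))
#     block4_idx = get_layer_number(backbone, 'block4_conv3')
#     skips = extract_outputs(backbone, ['block5_conv3', block4_idx],
#                             include_top=True)
#     # skips[0] is the backbone output, followed by the requested layer outputs.
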
def reverse(l):
    """Reverse a list."""
    return list(reversed(l))


# Decorator for model aliases, to append a shared doc string.
def add_docstring(doc_string=None):
    def decorator(fn):
        if fn.__doc__:
            fn.__doc__ += doc_string
        else:
            fn.__doc__ = doc_string

        @wraps(fn)
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)
        return wrapper
    return decorator
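
# Usage sketch (illustrative): appending a shared parameter description to a
# model-builder alias. `common_doc` and `Unet` below are hypothetical names.
#
#     common_doc = """
#     Args:
#         backbone_name: name of the classification model used as encoder.
#     """
#
#     @add_docstring(common_doc)
#     def Unet(backbone_name='vgg16', **kwargs):
#         """Builds a U-Net style segmentation model."""
#         ...
#
#     # help(Unet) now shows the original docstring followed by common_doc.
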
def recompile(model):
    """Recompile the model with its current optimizer, loss and metrics."""
    model.compile(model.optimizer, model.loss, model.metrics)


def freeze_model(model):
    """Set all model layers to non-trainable (weights are kept frozen)."""
    for layer in model.layers:
        layer.trainable = False
    return


def set_trainable(model):
    """Set all model layers to trainable and recompile the model."""
    for layer in model.layers:
        layer.trainable = True
    recompile(model)
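
# Usage sketch (illustrative): a typical two-stage fine-tuning flow. `backbone`,
# `model`, `x_train` and `y_train` are hypothetical; `model` must be compiled
# before set_trainable() so recompile() has an optimizer/loss to reuse.
#
#     freeze_model(backbone)          # train with the backbone weights frozen
#     model.compile('adam', 'binary_crossentropy', ['accuracy'])
#     model.fit(x_train, y_train, epochs=2)
#
#     set_trainable(model)            # unfreeze everything and recompile
#     model.fit(x_train, y_train, epochs=10)
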
def to_tuple(x):
    """Convert a scalar to a 2-tuple; pass 2-tuples through unchanged."""
    if isinstance(x, tuple):
        if len(x) == 2:
            return x
    elif np.isscalar(x):
        return x, x
    raise ValueError('Value should be tuple of length 2 or int value, got "{}"'.format(x))
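
# Usage sketch (illustrative): normalizing kernel-size style arguments.
#
#     to_tuple(3)          # -> (3, 3)
#     to_tuple((3, 5))     # -> (3, 5)
#     to_tuple((1, 2, 3))  # raises ValueError
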
def call_cascade(layers: Sequence[tf.keras.layers.Layer],
                 inp: Any, training: bool = True) -> Any:
    """
    Call a sequence of layers as a cascade, feeding each layer's output
    into the next one.
    Args:
        layers: a sequence of layers
        inp: input to the first layer of the sequence
        training: whether the layers are called in training mode
    Returns:
        output of the last layer
    """
    x = inp
    for l in layers:
        x = l(x, training=training)
    return x
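
# Usage sketch (illustrative): running a tensor through a small decoder head.
# The layer stack below is an assumption for demonstration only.
#
#     head = [tf.keras.layers.Dense(64, activation='relu'),
#             tf.keras.layers.Dropout(0.5),
#             tf.keras.layers.Dense(1, activation='sigmoid')]
#     preds = call_cascade(head, tf.zeros((8, 128)), training=False)
#     # preds.shape == (8, 1)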