[11ca2d]: / networks / ClassWeightMult.py

Download this file

13 lines (10 with data), 421 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Input, Concatenate, Maximum, Dropout, LeakyReLU
from tensorflow.keras.models import Model
class ClassWeightMult(tf.keras.layers.Layer):
    """Keras layer that multiplies its input by a fixed ``class_weight``.

    The weight is stored as a plain attribute (not a trainable variable) and
    applied with the ``*`` operator in ``call``.
    """

    def __init__(self, class_weight, **kwargs):
        """Store the multiplier applied to every input.

        Args:
            class_weight: Value the inputs are multiplied by.
                NOTE(review): presumably a scalar or a per-class vector that
                broadcasts against the input's last axis -- confirm with the
                caller.
            **kwargs: Standard ``tf.keras.layers.Layer`` keyword arguments
                (e.g. ``name``, ``dtype``), forwarded to the base class.
                Accepting them is also what lets the default ``from_config``
                rebuild the layer from the dict returned by ``get_config``.
        """
        super().__init__(**kwargs)
        self.class_weight = class_weight

    def call(self, inputs):
        """Return ``inputs * class_weight``."""
        return inputs * self.class_weight

    def get_config(self):
        """Return the layer config, including ``class_weight``.

        Without this override a layer that takes constructor arguments
        cannot be re-created from its config when a model is saved/cloned.
        NOTE(review): the value is returned as stored; if callers pass a
        ``tf.Tensor`` it may need converting before JSON serialization.
        """
        config = super().get_config()
        config["class_weight"] = self.class_weight
        return config