a b/networks/ClassWeightMult.py
1
import tensorflow as tf
2
from tensorflow.keras import Sequential
3
from tensorflow.keras.layers import Dense, Input, Concatenate, Maximum, Dropout, LeakyReLU
4
from tensorflow.keras.models import Model
5
6
7
class ClassWeightMult(tf.keras.layers.Layer):
    """Keras layer that scales its input element-wise by a fixed ``class_weight``.

    The forward pass is ``inputs * class_weight``. Ordinary TensorFlow
    broadcasting applies to that product, so ``class_weight`` may be a scalar
    or anything broadcast-compatible with ``inputs``. A per-class weight
    vector along the last axis is presumably the intended use; confirm this
    against callers.

    Args:
        class_weight: Multiplier applied to the input in :meth:`call`. It is
            stored as given; nothing here makes it trainable.
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments (``name``,
            ``dtype``, ``trainable``, ...). They are forwarded to the base
            class so the layer can be named and rebuilt via ``from_config``.
    """

    def __init__(self, class_weight, **kwargs):
        super().__init__(**kwargs)
        self.class_weight = class_weight

    def call(self, inputs):
        """Return ``inputs`` multiplied element-wise by ``class_weight``."""
        return inputs * self.class_weight

    def get_config(self):
        """Return the layer config, including ``class_weight``, for model saving/loading."""
        config = super().get_config()
        weight = self.class_weight
        # NumPy arrays/scalars are not JSON-serializable; store them as plain
        # Python lists/numbers. `inputs * list` still works after reload.
        if hasattr(weight, "tolist"):
            weight = weight.tolist()
        config["class_weight"] = weight
        return config