Switch to side-by-side view

--- a
+++ b/EfficientNet_2d/EfficientNet_2d.py
@@ -0,0 +1,460 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+from EfficientNet_2d.utils import (
+    round_filters,
+    round_repeats,
+    drop_connect,
+    get_same_padding_conv2d,
+    get_model_params,
+    efficientnet_params,
+    load_pretrained_weights,
+    Swish,
+    MemoryEfficientSwish,
+)
+
+
+class MBConvBlock(nn.Module):
+    """
+    Mobile Inverted Residual Bottleneck Block
+
+    Args:
+        block_args (namedtuple): BlockArgs, see above
+        global_params (namedtuple): GlobalParam, see above
+
+    Attributes:
+        has_se (bool): Whether the block contains a Squeeze and Excitation layer.
+    """
+
+    def __init__(self, block_args, global_params):
+        super().__init__()
+        self._block_args = block_args
+        self._bn_mom = 1 - global_params.batch_norm_momentum
+        self._bn_eps = global_params.batch_norm_epsilon
+        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
+        self.id_skip = block_args.id_skip  # skip connection and drop connect
+
+        # Get static or dynamic convolution depending on image size
+        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
+
+        # Expansion phase
+        inp = self._block_args.input_filters  # number of input channels
+        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
+        if self._block_args.expand_ratio != 1:
+            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
+            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+
+        # Depthwise convolution phase
+        k = self._block_args.kernel_size
+        s = self._block_args.stride
+        self._depthwise_conv = Conv2d(
+            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
+            kernel_size=k, stride=s, bias=False)
+        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+
+        # Squeeze and Excitation layer, if desired
+        if self.has_se:
+            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
+            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
+            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
+
+        # Output phase
+        final_oup = self._block_args.output_filters
+        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
+        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
+        self._swish = MemoryEfficientSwish()
+
+    def forward(self, inputs, drop_connect_rate=None):
+        """
+        :param inputs: input tensor
+        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
+        :return: output of block
+        """
+
+        # Expansion and Depthwise Convolution
+        x = inputs
+        if self._block_args.expand_ratio != 1:
+            x = self._swish(self._bn0(self._expand_conv(inputs)))
+        x = self._swish(self._bn1(self._depthwise_conv(x)))
+
+        # Squeeze and Excitation
+        if self.has_se:
+            x_squeezed = F.adaptive_avg_pool2d(x, 1)
+            x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
+            x = torch.sigmoid(x_squeezed) * x
+
+        x = self._bn2(self._project_conv(x))
+
+        # Skip connection and drop connect
+        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
+        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
+            if drop_connect_rate:
+                x = drop_connect(x, p=drop_connect_rate, training=self.training)
+            x = x + inputs  # skip connection
+        return x
+
+    def set_swish(self, memory_efficient=True):
+        """Sets swish function as memory efficient (for training) or standard (for export)"""
+        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+
+
+class EfficientNet(nn.Module):
+    """
+    An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
+
+    Args:
+        blocks_args (list): A list of BlockArgs to construct blocks
+        global_params (namedtuple): A set of GlobalParams shared between blocks
+
+    Example:
+        model = EfficientNet.from_pretrained('efficientnet-b0')
+
+    """
+    def __init__(self, blocks_args=None, global_params=None):
+        super().__init__()
+        assert isinstance(blocks_args, list), 'blocks_args should be a list'
+        assert len(blocks_args) > 0, 'block args must be greater than 0'
+        self._global_params = global_params
+        self._blocks_args = blocks_args
+
+        # Get static or dynamic convolution depending on image size
+        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
+
+        # Batch norm parameters
+        bn_mom = 1 - self._global_params.batch_norm_momentum
+        bn_eps = self._global_params.batch_norm_epsilon
+
+        # Stem
+        in_channels = 3  # rgb
+        out_channels = round_filters(32, self._global_params)  # number of output channels
+        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+        # Build blocks
+        self._blocks = nn.ModuleList([])
+        for block_args in self._blocks_args:
+
+            # Update block input and output filters based on depth multiplier.
+            block_args = block_args._replace(
+                input_filters=round_filters(block_args.input_filters, self._global_params),
+                output_filters=round_filters(block_args.output_filters, self._global_params),
+                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
+            )
+
+            # The first block needs to take care of stride and filter size increase.
+            self._blocks.append(MBConvBlock(block_args, self._global_params))
+            if block_args.num_repeat > 1:
+                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
+            for _ in range(block_args.num_repeat - 1):
+                self._blocks.append(MBConvBlock(block_args, self._global_params))
+
+        # Head
+        in_channels = block_args.output_filters  # output of final block
+        out_channels = round_filters(1280, self._global_params)
+        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+        # Final linear layer
+        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
+        self._dropout = nn.Dropout(self._global_params.dropout_rate)
+        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+        self._swish = MemoryEfficientSwish()
+
+    def set_swish(self, memory_efficient=True):
+        """Sets swish function as memory efficient (for training) or standard (for export)"""
+        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+        for block in self._blocks:
+            block.set_swish(memory_efficient)
+
+    def extract_features(self, inputs):
+        """ Returns output of the final convolution layer """
+        # Stem
+        x = self._swish(self._bn0(self._conv_stem(inputs)))
+
+        # Blocks
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(self._blocks)
+            x = block(x, drop_connect_rate=drop_connect_rate)
+
+        # Head
+        x = self._swish(self._bn1(self._conv_head(x)))
+
+        return x
+
+    def forward(self, inputs):
+        """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
+        bs = inputs.size(0)
+        # Convolution layers
+        x = self.extract_features(inputs)
+        # Pooling and final linear layer
+        x = self._avg_pooling(x)
+        x = x.view(bs, -1)
+        x = self._dropout(x)
+        x = self._fc(x)
+        return x
+
+    @classmethod
+    def from_name(cls, model_name, override_params=None):
+        cls._check_model_name_is_valid(model_name)
+        blocks_args, global_params = get_model_params(model_name, override_params)
+        return cls(blocks_args, global_params)
+
+    @classmethod
+    def from_pretrained(cls, model_name, advprop=False, num_classes=1000, in_channels=3):
+        model = cls.from_name(model_name, override_params={'num_classes': num_classes})
+        load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
+        if in_channels != 3:
+            Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size)
+            out_channels = round_filters(32, model._global_params)
+            model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        return model
+    
+    @classmethod
+    def get_image_size(cls, model_name):
+        cls._check_model_name_is_valid(model_name)
+        _, _, res, _ = efficientnet_params(model_name)
+        return res
+
+    @classmethod
+    def _check_model_name_is_valid(cls, model_name):
+        """ Validates model name. """ 
+        valid_models = ['efficientnet-b'+str(i) for i in range(9)]
+        if model_name not in valid_models:
+            raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
+
+
+# get pretrained EfficientNet for k-classes classification
+def get_pretrained_EfficientNet(num_classes):
+    model = EfficientNet.from_pretrained('efficientnet-b0')
+    fc_features = model._fc.in_features
+    model._fc = nn.Linear(fc_features, num_classes)
+    return model
+
+
+class DAR_Effi(nn.Module):
+    def __init__(self, blocks_args=None, global_params=None, in_channels=3, att_start=11):
+        super(DAR_Effi, self).__init__()
+        assert isinstance(blocks_args, list), 'blocks_args should be a list'
+        assert len(blocks_args) > 0, 'block args must be greater than 0'
+        self._global_params = global_params
+        self._blocks_args = blocks_args
+        self.att_start = att_start  # for CA-module and NA-module
+
+        # Get static or dynamic convolution depending on image size
+        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
+
+        # Batch norm parameters
+        bn_mom = 1 - self._global_params.batch_norm_momentum
+        bn_eps = self._global_params.batch_norm_epsilon
+
+        # Stem
+        out_channels = round_filters(32, self._global_params)  # number of output channels
+        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+        self._conv_stem_cf = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._bn0_cf = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+        self._conv_stem_lr = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._bn0_lr = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+        # Build blocks of Prd-Net
+        self._blocks = nn.ModuleList([])
+        for block_args in self._blocks_args:
+
+            # Update block input and output filters based on depth multiplier.
+            block_args = block_args._replace(
+                input_filters=round_filters(block_args.input_filters, self._global_params),
+                output_filters=round_filters(block_args.output_filters, self._global_params),
+                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
+            )
+
+            # The first block needs to take care of stride and filter size increase.
+            self._blocks.append(MBConvBlock(block_args, self._global_params))
+            if block_args.num_repeat > 1:
+                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
+            for _ in range(block_args.num_repeat - 1):
+                self._blocks.append(MBConvBlock(block_args, self._global_params))
+
+        # Build blocks of CF-Net
+        self._blocks_cf = nn.ModuleList([])
+        for block_args in self._blocks_args:
+
+            # Update block input and output filters based on depth multiplier.
+            block_args = block_args._replace(
+                input_filters=round_filters(block_args.input_filters, self._global_params),
+                output_filters=round_filters(block_args.output_filters, self._global_params),
+                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
+            )
+
+            # The first block needs to take care of stride and filter size increase.
+            self._blocks_cf.append(MBConvBlock(block_args, self._global_params))
+            if block_args.num_repeat > 1:
+                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
+            for _ in range(block_args.num_repeat - 1):
+                self._blocks_cf.append(MBConvBlock(block_args, self._global_params))
+
+        # Build blocks of LR-Net
+        self._blocks_lr = nn.ModuleList([])
+        for block_args in self._blocks_args:
+
+            # Update block input and output filters based on depth multiplier.
+            block_args = block_args._replace(
+                input_filters=round_filters(block_args.input_filters, self._global_params),
+                output_filters=round_filters(block_args.output_filters, self._global_params),
+                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
+            )
+
+            # The first block needs to take care of stride and filter size increase.
+            self._blocks_lr.append(MBConvBlock(block_args, self._global_params))
+            if block_args.num_repeat > 1:
+                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
+            for _ in range(block_args.num_repeat - 1):
+                self._blocks_lr.append(MBConvBlock(block_args, self._global_params))
+
+        # Head
+        in_channels = block_args.output_filters  # output of final block
+        out_channels = round_filters(1280, self._global_params)
+        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+        self._conv_head_cf = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+        self._bn1_cf = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+        self._conv_head_lr = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+        self._bn1_lr = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+        # Final linear layer
+        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
+        self._dropout = nn.Dropout(self._global_params.dropout_rate)
+        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+        self._swish = MemoryEfficientSwish()
+
+        self._avg_pooling_cf = nn.AdaptiveAvgPool2d(1)
+        self._dropout_cf = nn.Dropout(self._global_params.dropout_rate)
+        self._fc_cf = nn.Linear(out_channels, self._global_params.num_classes)
+        self._swish_cf = MemoryEfficientSwish()
+
+        self._avg_pooling_lr = nn.AdaptiveAvgPool2d(1)
+        self._dropout_lr = nn.Dropout(self._global_params.dropout_rate)
+        self._fc_lr = nn.Linear(out_channels, self._global_params.num_classes)
+        self._swish_lr = MemoryEfficientSwish()
+
+    def set_swish(self, memory_efficient=True):
+        """Sets swish function as memory efficient (for training) or standard (for export)"""
+        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+        for block in self._blocks:
+            block.set_swish(memory_efficient)
+
+        self._swish_cf = MemoryEfficientSwish() if memory_efficient else Swish()
+        for block_cf in self._blocks_cf:
+            block_cf.set_swish(memory_efficient)
+
+        self._swish_lr = MemoryEfficientSwish() if memory_efficient else Swish()
+        for block_lr in self._blocks_lr:
+            block_lr.set_swish(memory_efficient)
+
+    def attention(self, f_prd, f_cf, f_lr):
+        w_cf = 1 - torch.sigmoid(f_cf)
+        add_cf = w_cf * f_prd
+
+        w_lr = 1 - abs(torch.sigmoid(f_prd)-torch.sigmoid(f_lr))
+        add_lr = w_lr * f_prd
+
+        f_prd = f_prd + add_cf + add_lr
+        return f_prd
+
+    def extract_features(self, inputs):
+        """ Returns output of the final convolution layer """
+
+        # Stem
+        x = self._swish(self._bn0(self._conv_stem(inputs)))
+        x_cf = self._swish_cf(self._bn0_cf(self._conv_stem_cf(inputs)))
+        x_lr = self._swish_lr(self._bn0_lr(self._conv_stem_lr(inputs)))
+
+        # Blocks
+        for idx, block in enumerate(self._blocks):
+            block_cf = self._blocks_cf[idx]
+            block_lr = self._blocks_lr[idx]
+
+            drop_connect_rate = self._global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(self._blocks)
+
+            x = block(x, drop_connect_rate=drop_connect_rate)
+            x_cf = block_cf(x_cf, drop_connect_rate=drop_connect_rate)
+            x_lr = block_lr(x_lr, drop_connect_rate=drop_connect_rate)
+
+            if idx >= self.att_start:
+                x = self.attention(x, x_cf, x_lr)
+
+        # Head
+        x = self._swish(self._bn1(self._conv_head(x)))
+        x_cf = self._swish_cf(self._bn1_cf(self._conv_head_cf(x_cf)))
+        x_lr = self._swish_lr(self._bn1_lr(self._conv_head_lr(x_lr)))
+
+        return x, x_cf, x_lr
+
+    def forward(self, inputs):
+        bs = inputs.size(0)
+        # Convolution layers
+        x, x_cf, x_lr = self.extract_features(inputs)
+
+        # Pooling and final linear layer
+        x = self._avg_pooling(x)
+        x = x.view(bs, -1)
+        x = self._dropout(x)
+        x = self._fc(x)
+
+        x_cf = self._avg_pooling_cf(x_cf)
+        x_cf = x_cf.view(bs, -1)
+        x_cf = self._dropout_cf(x_cf)
+        x_cf = self._fc_cf(x_cf)
+
+        x_lr = self._avg_pooling_lr(x_lr)
+        x_lr = x_lr.view(bs, -1)
+        x_lr = self._dropout_lr(x_lr)
+        x_lr = self._fc_lr(x_lr)
+
+        return x, x_cf, x_lr
+
+    @classmethod
+    def from_name(cls, model_name, override_params=None, in_channels=3, att_start=11):
+        cls._check_model_name_is_valid(model_name)
+        blocks_args, global_params = get_model_params(model_name, override_params)
+        return cls(blocks_args, global_params, in_channels, att_start)
+
+    @classmethod
+    def get_image_size(cls, model_name):
+        cls._check_model_name_is_valid(model_name)
+        _, _, res, _ = efficientnet_params(model_name)
+        return res
+
+    @classmethod
+    def _check_model_name_is_valid(cls, model_name):
+        """ Validates model name. """
+        valid_models = ['efficientnet-b'+str(i) for i in range(9)]
+        if model_name not in valid_models:
+            raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
+
+
+def get_pretrained_DAR(prd_params, cf_params, lr_params, num_classes):
+
+    dar_model = DAR_Effi.from_name('efficientnet-b0')
+    fc_features = dar_model._fc.in_features
+    dar_model._fc = nn.Linear(fc_features, num_classes)
+    dar_model._fc_cf = nn.Linear(fc_features, num_classes)
+    dar_model._fc_lr = nn.Linear(fc_features, num_classes)
+    dar_params = dar_model.state_dict()
+
+    for k, v in prd_params.items():
+        index_point = k.find('.')
+        k_apart = k[0:index_point]
+        k_bpart = k[index_point:len(k)]
+        k_cf = k_apart + '_cf' + k_bpart
+        k_lr = k_apart + '_lr' + k_bpart
+
+        dar_params[k] = prd_params[k]
+        dar_params[k_cf] = cf_params[k]
+        dar_params[k_lr] = lr_params[k]
+
+    dar_model.load_state_dict(dar_params)
+    return dar_model