--- a +++ b/EfficientNet_2d/EfficientNet_2d.py @@ -0,0 +1,460 @@ +import torch +from torch import nn +from torch.nn import functional as F +from EfficientNet_2d.utils import ( + round_filters, + round_repeats, + drop_connect, + get_same_padding_conv2d, + get_model_params, + efficientnet_params, + load_pretrained_weights, + Swish, + MemoryEfficientSwish, +) + + +class MBConvBlock(nn.Module): + """ + Mobile Inverted Residual Bottleneck Block + + Args: + block_args (namedtuple): BlockArgs, see above + global_params (namedtuple): GlobalParam, see above + + Attributes: + has_se (bool): Whether the block contains a Squeeze and Excitation layer. + """ + + def __init__(self, block_args, global_params): + super().__init__() + self._block_args = block_args + self._bn_mom = 1 - global_params.batch_norm_momentum + self._bn_eps = global_params.batch_norm_epsilon + self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) + self.id_skip = block_args.id_skip # skip connection and drop connect + + # Get static or dynamic convolution depending on image size + Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) + + # Expansion phase + inp = self._block_args.input_filters # number of input channels + oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels + if self._block_args.expand_ratio != 1: + self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) + self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + + # Depthwise convolution phase + k = self._block_args.kernel_size + s = self._block_args.stride + self._depthwise_conv = Conv2d( + in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise + kernel_size=k, stride=s, bias=False) + self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + + # Squeeze and Excitation layer, if desired + if self.has_se: + num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) + self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) + self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) + + # Output phase + final_oup = self._block_args.output_filters + self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) + self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) + self._swish = MemoryEfficientSwish() + + def forward(self, inputs, drop_connect_rate=None): + """ + :param inputs: input tensor + :param drop_connect_rate: drop connect rate (float, between 0 and 1) + :return: output of block + """ + + # Expansion and Depthwise Convolution + x = inputs + if self._block_args.expand_ratio != 1: + x = self._swish(self._bn0(self._expand_conv(inputs))) + x = self._swish(self._bn1(self._depthwise_conv(x))) + + # Squeeze and Excitation + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed))) + x = torch.sigmoid(x_squeezed) * x + + x = self._bn2(self._project_conv(x)) + + # Skip connection and drop connect + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: + if drop_connect_rate: + x = drop_connect(x, p=drop_connect_rate, training=self.training) + x = x + inputs # skip connection + return x + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export)""" + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + + +class EfficientNet(nn.Module): + """ + An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods + + Args: + blocks_args (list): A list of BlockArgs to construct blocks + global_params (namedtuple): A set of GlobalParams shared between blocks + + Example: + model = EfficientNet.from_pretrained('efficientnet-b0') + + """ + def __init__(self, blocks_args=None, global_params=None): + super().__init__() + assert isinstance(blocks_args, list), 'blocks_args should be a list' + assert len(blocks_args) > 0, 'block args must be greater than 0' + self._global_params = global_params + self._blocks_args = blocks_args + + # Get static or dynamic convolution depending on image size + Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) + + # Batch norm parameters + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + + # Stem + in_channels = 3 # rgb + out_channels = round_filters(32, self._global_params) # number of output channels + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) + self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + # Build blocks + self._blocks = nn.ModuleList([]) + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, self._global_params), + output_filters=round_filters(block_args.output_filters, self._global_params), + num_repeat=round_repeats(block_args.num_repeat, self._global_params) + ) + + # The first block needs to take care of stride and filter size increase. + self._blocks.append(MBConvBlock(block_args, self._global_params)) + if block_args.num_repeat > 1: + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks.append(MBConvBlock(block_args, self._global_params)) + + # Head + in_channels = block_args.output_filters # output of final block + out_channels = round_filters(1280, self._global_params) + self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + # Final linear layer + self._avg_pooling = nn.AdaptiveAvgPool2d(1) + self._dropout = nn.Dropout(self._global_params.dropout_rate) + self._fc = nn.Linear(out_channels, self._global_params.num_classes) + self._swish = MemoryEfficientSwish() + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export)""" + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + def extract_features(self, inputs): + """ Returns output of the final convolution layer """ + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) + x = block(x, drop_connect_rate=drop_connect_rate) + + # Head + x = self._swish(self._bn1(self._conv_head(x))) + + return x + + def forward(self, inputs): + """ Calls extract_features to extract features, applies final linear layer, and returns logits. """ + bs = inputs.size(0) + # Convolution layers + x = self.extract_features(inputs) + # Pooling and final linear layer + x = self._avg_pooling(x) + x = x.view(bs, -1) + x = self._dropout(x) + x = self._fc(x) + return x + + @classmethod + def from_name(cls, model_name, override_params=None): + cls._check_model_name_is_valid(model_name) + blocks_args, global_params = get_model_params(model_name, override_params) + return cls(blocks_args, global_params) + + @classmethod + def from_pretrained(cls, model_name, advprop=False, num_classes=1000, in_channels=3): + model = cls.from_name(model_name, override_params={'num_classes': num_classes}) + load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop) + if in_channels != 3: + Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size) + out_channels = round_filters(32, model._global_params) + model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) + return model + + @classmethod + def get_image_size(cls, model_name): + cls._check_model_name_is_valid(model_name) + _, _, res, _ = efficientnet_params(model_name) + return res + + @classmethod + def _check_model_name_is_valid(cls, model_name): + """ Validates model name. """ + valid_models = ['efficientnet-b'+str(i) for i in range(9)] + if model_name not in valid_models: + raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) + + +# get pretrained EfficientNet for k-classes classification +def get_pretrained_EfficientNet(num_classes): + model = EfficientNet.from_pretrained('efficientnet-b0') + fc_features = model._fc.in_features + model._fc = nn.Linear(fc_features, num_classes) + return model + + +class DAR_Effi(nn.Module): + def __init__(self, blocks_args=None, global_params=None, in_channels=3, att_start=11): + super(DAR_Effi, self).__init__() + assert isinstance(blocks_args, list), 'blocks_args should be a list' + assert len(blocks_args) > 0, 'block args must be greater than 0' + self._global_params = global_params + self._blocks_args = blocks_args + self.att_start = att_start # for CA-module and NA-module + + # Get static or dynamic convolution depending on image size + Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) + + # Batch norm parameters + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + + # Stem + out_channels = round_filters(32, self._global_params) # number of output channels + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) + self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + self._conv_stem_cf = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) + self._bn0_cf = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + self._conv_stem_lr = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) + self._bn0_lr = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + # Build blocks of Prd-Net + self._blocks = nn.ModuleList([]) + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, self._global_params), + output_filters=round_filters(block_args.output_filters, self._global_params), + num_repeat=round_repeats(block_args.num_repeat, self._global_params) + ) + + # The first block needs to take care of stride and filter size increase. + self._blocks.append(MBConvBlock(block_args, self._global_params)) + if block_args.num_repeat > 1: + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks.append(MBConvBlock(block_args, self._global_params)) + + # Build blocks of CF-Net + self._blocks_cf = nn.ModuleList([]) + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, self._global_params), + output_filters=round_filters(block_args.output_filters, self._global_params), + num_repeat=round_repeats(block_args.num_repeat, self._global_params) + ) + + # The first block needs to take care of stride and filter size increase. + self._blocks_cf.append(MBConvBlock(block_args, self._global_params)) + if block_args.num_repeat > 1: + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks_cf.append(MBConvBlock(block_args, self._global_params)) + + # Build blocks of LR-Net + self._blocks_lr = nn.ModuleList([]) + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, self._global_params), + output_filters=round_filters(block_args.output_filters, self._global_params), + num_repeat=round_repeats(block_args.num_repeat, self._global_params) + ) + + # The first block needs to take care of stride and filter size increase. + self._blocks_lr.append(MBConvBlock(block_args, self._global_params)) + if block_args.num_repeat > 1: + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks_lr.append(MBConvBlock(block_args, self._global_params)) + + # Head + in_channels = block_args.output_filters # output of final block + out_channels = round_filters(1280, self._global_params) + self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + self._conv_head_cf = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self._bn1_cf = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + self._conv_head_lr = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self._bn1_lr = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + # Final linear layer + self._avg_pooling = nn.AdaptiveAvgPool2d(1) + self._dropout = nn.Dropout(self._global_params.dropout_rate) + self._fc = nn.Linear(out_channels, self._global_params.num_classes) + self._swish = MemoryEfficientSwish() + + self._avg_pooling_cf = nn.AdaptiveAvgPool2d(1) + self._dropout_cf = nn.Dropout(self._global_params.dropout_rate) + self._fc_cf = nn.Linear(out_channels, self._global_params.num_classes) + self._swish_cf = MemoryEfficientSwish() + + self._avg_pooling_lr = nn.AdaptiveAvgPool2d(1) + self._dropout_lr = nn.Dropout(self._global_params.dropout_rate) + self._fc_lr = nn.Linear(out_channels, self._global_params.num_classes) + self._swish_lr = MemoryEfficientSwish() + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export)""" + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + self._swish_cf = MemoryEfficientSwish() if memory_efficient else Swish() + for block_cf in self._blocks_cf: + block_cf.set_swish(memory_efficient) + + self._swish_lr = MemoryEfficientSwish() if memory_efficient else Swish() + for block_lr in self._blocks_lr: + block_lr.set_swish(memory_efficient) + + def attention(self, f_prd, f_cf, f_lr): + w_cf = 1 - torch.sigmoid(f_cf) + add_cf = w_cf * f_prd + + w_lr = 1 - abs(torch.sigmoid(f_prd)-torch.sigmoid(f_lr)) + add_lr = w_lr * f_prd + + f_prd = f_prd + add_cf + add_lr + return f_prd + + def extract_features(self, inputs): + """ Returns output of the final convolution layer """ + + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + x_cf = self._swish_cf(self._bn0_cf(self._conv_stem_cf(inputs))) + x_lr = self._swish_lr(self._bn0_lr(self._conv_stem_lr(inputs))) + + # Blocks + for idx, block in enumerate(self._blocks): + block_cf = self._blocks_cf[idx] + block_lr = self._blocks_lr[idx] + + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) + + x = block(x, drop_connect_rate=drop_connect_rate) + x_cf = block_cf(x_cf, drop_connect_rate=drop_connect_rate) + x_lr = block_lr(x_lr, drop_connect_rate=drop_connect_rate) + + if idx >= self.att_start: + x = self.attention(x, x_cf, x_lr) + + # Head + x = self._swish(self._bn1(self._conv_head(x))) + x_cf = self._swish_cf(self._bn1_cf(self._conv_head_cf(x_cf))) + x_lr = self._swish_lr(self._bn1_lr(self._conv_head_lr(x_lr))) + + return x, x_cf, x_lr + + def forward(self, inputs): + bs = inputs.size(0) + # Convolution layers + x, x_cf, x_lr = self.extract_features(inputs) + + # Pooling and final linear layer + x = self._avg_pooling(x) + x = x.view(bs, -1) + x = self._dropout(x) + x = self._fc(x) + + x_cf = self._avg_pooling_cf(x_cf) + x_cf = x_cf.view(bs, -1) + x_cf = self._dropout_cf(x_cf) + x_cf = self._fc_cf(x_cf) + + x_lr = self._avg_pooling_lr(x_lr) + x_lr = x_lr.view(bs, -1) + x_lr = self._dropout_lr(x_lr) + x_lr = self._fc_lr(x_lr) + + return x, x_cf, x_lr + + @classmethod + def from_name(cls, model_name, override_params=None, in_channels=3, att_start=11): + cls._check_model_name_is_valid(model_name) + blocks_args, global_params = get_model_params(model_name, override_params) + return cls(blocks_args, global_params, in_channels, att_start) + + @classmethod + def get_image_size(cls, model_name): + cls._check_model_name_is_valid(model_name) + _, _, res, _ = efficientnet_params(model_name) + return res + + @classmethod + def _check_model_name_is_valid(cls, model_name): + """ Validates model name. """ + valid_models = ['efficientnet-b'+str(i) for i in range(9)] + if model_name not in valid_models: + raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) + + +def get_pretrained_DAR(prd_params, cf_params, lr_params, num_classes): + + dar_model = DAR_Effi.from_name('efficientnet-b0') + fc_features = dar_model._fc.in_features + dar_model._fc = nn.Linear(fc_features, num_classes) + dar_model._fc_cf = nn.Linear(fc_features, num_classes) + dar_model._fc_lr = nn.Linear(fc_features, num_classes) + dar_params = dar_model.state_dict() + + for k, v in prd_params.items(): + index_point = k.find('.') + k_apart = k[0:index_point] + k_bpart = k[index_point:len(k)] + k_cf = k_apart + '_cf' + k_bpart + k_lr = k_apart + '_lr' + k_bpart + + dar_params[k] = prd_params[k] + dar_params[k_cf] = cf_params[k] + dar_params[k_lr] = lr_params[k] + + dar_model.load_state_dict(dar_params) + return dar_model