--- /dev/null
+++ b/mmaction/models/backbones/amagi_slowfast.py
@@ -0,0 +1,556 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, kaiming_init
+from mmcv.runner import _load_checkpoint, load_checkpoint
+from mmcv.utils import print_log
+
+from ...utils import get_root_logger
+from ..builder import BACKBONES
+from .resnet3d import ResNet3d
+
+try:
+    from mmdet.models import BACKBONES as MMDET_BACKBONES
+    mmdet_imported = True
+except (ImportError, ModuleNotFoundError):
+    mmdet_imported = False
+
+
+# AMAGI: auxiliary HRNet-OCR segmentation branch
+
+from ..segmentors.seg_hrnet_ocr import get_seg_model
+import yaml
+
+
+def load_seg_model(kfold=1):
+    """Load the HRNet-OCR segmentation model from its YAML config.
+
+    ``kfold`` (an int, or a string convertible to int) is forwarded to
+    ``get_seg_model``.
+    """
+    with open('mmaction/models/segmentors/config/ocr.yml', 'r') as f:
+        cfg = yaml.safe_load(f)
+    if isinstance(kfold, str):
+        kfold = int(kfold)
+    seg_model = get_seg_model(cfg, kfold)
+    return seg_model
+
+
+class AMAGIPathway(ResNet3d):
+    """A pathway of SlowFast based on ResNet3d.
+
+    Args:
+        *args (arguments): Arguments same as :class:`ResNet3d`.
+        lateral (bool): Determines whether to enable the lateral connection
+            from another pathway. Default: False.
+        speed_ratio (int): Speed ratio indicating the ratio between time
+            dimension of the fast and slow pathway, corresponding to the
+            ``alpha`` in the paper. Default: 8.
+        channel_ratio (int): Reduce the channel number of fast pathway
+            by ``channel_ratio``, corresponding to ``beta`` in the paper.
+            Default: 8.
+        fusion_kernel (int): The kernel size of lateral fusion.
+            Default: 5.
+        **kwargs (keyword arguments): Keywords arguments for ResNet3d.
+    """
+
+    def __init__(self,
+                 *args,
+                 lateral=False,
+                 speed_ratio=8,
+                 channel_ratio=8,
+                 fusion_kernel=5,
+                 **kwargs):
+        self.lateral = lateral
+        self.speed_ratio = speed_ratio
+        self.channel_ratio = channel_ratio
+        self.fusion_kernel = fusion_kernel
+        super().__init__(*args, **kwargs)
+        self.inplanes = self.base_channels
+        if self.lateral:
+            self.conv1_lateral = ConvModule(
+                self.inplanes // self.channel_ratio,
+                # https://arxiv.org/abs/1812.03982, the
+                # third type of lateral connection has out_channel:
+                # 2 * \beta * C
+                self.inplanes * 2 // self.channel_ratio,
+                kernel_size=(fusion_kernel, 1, 1),
+                stride=(self.speed_ratio, 1, 1),
+                padding=((fusion_kernel - 1) // 2, 0, 0),
+                bias=False,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=None,
+                act_cfg=None)
+
+        self.lateral_connections = []
+        for i in range(len(self.stage_blocks)):
+            planes = self.base_channels * 2**i
+            self.inplanes = planes * self.block.expansion
+
+            if lateral and i != self.num_stages - 1:
+                # no lateral connection needed in final stage
+                lateral_name = f'layer{(i + 1)}_lateral'
+                setattr(
+                    self, lateral_name,
+                    ConvModule(
+                        self.inplanes // self.channel_ratio,
+                        self.inplanes * 2 // self.channel_ratio,
+                        kernel_size=(fusion_kernel, 1, 1),
+                        stride=(self.speed_ratio, 1, 1),
+                        padding=((fusion_kernel - 1) // 2, 0, 0),
+                        bias=False,
+                        conv_cfg=self.conv_cfg,
+                        norm_cfg=None,
+                        act_cfg=None))
+                self.lateral_connections.append(lateral_name)
+
+    def make_res_layer(self,
+                       block,
+                       inplanes,
+                       planes,
+                       blocks,
+                       spatial_stride=1,
+                       temporal_stride=1,
+                       dilation=1,
+                       style='pytorch',
+                       inflate=1,
+                       inflate_style='3x1x1',
+                       non_local=0,
+                       non_local_cfg=dict(),
+                       conv_cfg=None,
+                       norm_cfg=None,
+                       act_cfg=None,
+                       with_cp=False):
+        """Build residual layer for Slowfast.
+
+        Args:
+            block (nn.Module): Residual module to be built.
+            inplanes (int): Number of channels for the input
+                feature in each block.
+            planes (int): Number of channels for the output
+                feature in each block.
+            blocks (int): Number of residual blocks.
+            spatial_stride (int | Sequence[int]): Spatial strides
+                in residual and conv layers. Default: 1.
+            temporal_stride (int | Sequence[int]): Temporal strides in
+                residual and conv layers. Default: 1.
+            dilation (int): Spacing between kernel elements. Default: 1.
+            style (str): ``pytorch`` or ``caffe``. If set to ``pytorch``,
+                the stride-two layer is the 3x3 conv layer,
+                otherwise the stride-two layer is the first 1x1 conv layer.
+                Default: ``pytorch``.
+            inflate (int | Sequence[int]): Determine whether to inflate
+                for each block. Default: 1.
+            inflate_style (str): ``3x1x1`` or ``3x3x3``, which determines
+                the kernel sizes and padding strides for conv1 and
+                conv2 in each block. Default: ``3x1x1``.
+            non_local (int | Sequence[int]): Determine whether to apply
+                non-local module in the corresponding block of each stage.
+                Default: 0.
+            non_local_cfg (dict): Config for non-local module.
+                Default: ``dict()``.
+            conv_cfg (dict | None): Config for conv layers. Default: None.
+            norm_cfg (dict | None): Config for norm layers. Default: None.
+            act_cfg (dict | None): Config for activation layers. Default: None.
+            with_cp (bool): Use checkpoint or not. Using checkpoint will save
+                some memory while slowing down the training speed.
+                Default: False.
+
+        Returns:
+            nn.Module: A residual layer for the given config.
+        """
+        inflate = inflate if not isinstance(inflate,
+                                            int) else (inflate, ) * blocks
+        non_local = non_local if not isinstance(
+            non_local, int) else (non_local, ) * blocks
+        assert len(inflate) == blocks and len(non_local) == blocks
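+        # With lateral connections, the fused fast-pathway features add
+        # 2 * inplanes / channel_ratio extra input channels (the lateral
+        # conv outputs 2 * beta * C channels, per the SlowFast paper).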
+        if self.lateral:
+            lateral_inplanes = inplanes * 2 // self.channel_ratio
+        else:
+            lateral_inplanes = 0
+        if (spatial_stride != 1
+                or (inplanes + lateral_inplanes) != planes * block.expansion):
+            downsample = ConvModule(
+                inplanes + lateral_inplanes,
+                planes * block.expansion,
+                kernel_size=1,
+                stride=(temporal_stride, spatial_stride, spatial_stride),
+                bias=False,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=None)
+        else:
+            downsample = None
+
+        layers = []
+        layers.append(
+            block(
+                inplanes + lateral_inplanes,
+                planes,
+                spatial_stride,
+                temporal_stride,
+                dilation,
+                downsample,
+                style=style,
+                inflate=(inflate[0] == 1),
+                inflate_style=inflate_style,
+                non_local=(non_local[0] == 1),
+                non_local_cfg=non_local_cfg,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg,
+                with_cp=with_cp))
+        inplanes = planes * block.expansion
+
+        for i in range(1, blocks):
+            layers.append(
+                block(
+                    inplanes,
+                    planes,
+                    1,
+                    1,
+                    dilation,
+                    style=style,
+                    inflate=(inflate[i] == 1),
+                    inflate_style=inflate_style,
+                    non_local=(non_local[i] == 1),
+                    non_local_cfg=non_local_cfg,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    with_cp=with_cp))
+
+        return nn.Sequential(*layers)
+
+    def inflate_weights(self, logger):
+        """Inflate the resnet2d parameters to resnet3d pathway.
+
+        The differences between resnet3d and resnet2d mainly lie in an extra
+        axis of conv kernel. To utilize the pretrained parameters in 2d model,
+        the weight of conv2d models should be inflated to fit in the shapes of
+        the 3d counterpart. For a pathway, the ``lateral_connections`` part
+        should not be inflated from 2d weights.
+
+        Args:
+            logger (logging.Logger): The logger used to print
+                debugging information.
+        """
+
+        state_dict_r2d = _load_checkpoint(self.pretrained)
+        if 'state_dict' in state_dict_r2d:
+            state_dict_r2d = state_dict_r2d['state_dict']
+
+        inflated_param_names = []
+        for name, module in self.named_modules():
+            if 'lateral' in name:
+                continue
+            if isinstance(module, ConvModule):
+                # we use a ConvModule to wrap conv+bn+relu layers, thus the
+                # name mapping is needed
+                if 'downsample' in name:
+                    # layer{X}.{Y}.downsample.conv->layer{X}.{Y}.downsample.0
+                    original_conv_name = name + '.0'
+                    # layer{X}.{Y}.downsample.bn->layer{X}.{Y}.downsample.1
+                    original_bn_name = name + '.1'
+                else:
+                    # layer{X}.{Y}.conv{n}.conv->layer{X}.{Y}.conv{n}
+                    original_conv_name = name
+                    # layer{X}.{Y}.conv{n}.bn->layer{X}.{Y}.bn{n}
+                    original_bn_name = name.replace('conv', 'bn')
+                if original_conv_name + '.weight' not in state_dict_r2d:
+                    logger.warning(f'Module does not exist in the '
+                                   f'state_dict_r2d: {original_conv_name}')
+                else:
+                    self._inflate_conv_params(module.conv, state_dict_r2d,
+                                              original_conv_name,
+                                              inflated_param_names)
+                if original_bn_name + '.weight' not in state_dict_r2d:
+                    logger.warning(f'Module does not exist in the '
+                                   f'state_dict_r2d: {original_bn_name}')
+                else:
+                    self._inflate_bn_params(module.bn, state_dict_r2d,
+                                            original_bn_name,
+                                            inflated_param_names)
+
+        # check if any parameters in the 2d checkpoint are not loaded
+        remaining_names = set(
+            state_dict_r2d.keys()) - set(inflated_param_names)
+        if remaining_names:
+            logger.info(f'These parameters in the 2d checkpoint are not loaded'
+                        f': {remaining_names}')
+
+    def _inflate_conv_params(self, conv3d, state_dict_2d, module_name_2d,
+                             inflated_param_names):
+        """Inflate a conv module from 2d to 3d.
+
+        The differences between 2d and 3d conv modules in a pathway mainly
+        lie in the inplanes due to lateral connections. To fit the shapes of
+        the lateral connection counterpart, the conv2d parameters are
+        expanded by concatenating extra zero-padded channels.
+
+        Args:
+            conv3d (nn.Module): The destination conv3d module.
+            state_dict_2d (OrderedDict): The state dict of pretrained 2d model.
+            module_name_2d (str): The name of corresponding conv module in the
+                2d model.
+            inflated_param_names (list[str]): List of parameters that have been
+                inflated.
+        """
+        weight_2d_name = module_name_2d + '.weight'
+        conv2d_weight = state_dict_2d[weight_2d_name]
+        old_shape = conv2d_weight.shape
+        new_shape = conv3d.weight.data.shape
+        kernel_t = new_shape[2]
+
+        if new_shape[1] != old_shape[1]:
+            if new_shape[1] < old_shape[1]:
+                warnings.warn(f'The parameter of {module_name_2d} is not '
+                              f'loaded due to incompatible shapes.')
+                return
+            # Inplanes may be different due to lateral connections
+            new_channels = new_shape[1] - old_shape[1]
+            pad_shape = old_shape
+            pad_shape = pad_shape[:1] + (new_channels, ) + pad_shape[2:]
+            # Expand parameters by concat extra channels
+            conv2d_weight = torch.cat(
+                (conv2d_weight,
+                 torch.zeros(pad_shape).type_as(conv2d_weight).to(
+                     conv2d_weight.device)),
+                dim=1)
+
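+        # Replicate the 2d kernel along the temporal axis and divide by
+        # kernel_t, so the inflated filter matches the 2d response on a
+        # temporally constant input (standard I3D-style inflation).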
+        new_weight = conv2d_weight.data.unsqueeze(2).expand_as(
+            conv3d.weight) / kernel_t
+        conv3d.weight.data.copy_(new_weight)
+        inflated_param_names.append(weight_2d_name)
+
+        if conv3d.bias is not None:
+            bias_2d_name = module_name_2d + '.bias'
+            conv3d.bias.data.copy_(state_dict_2d[bias_2d_name])
+            inflated_param_names.append(bias_2d_name)
+
+    def _freeze_stages(self):
+        """Prevent all the parameters from being optimized before
+        `self.frozen_stages`."""
+        if self.frozen_stages >= 0:
+            self.conv1.eval()
+            for param in self.conv1.parameters():
+                param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            m = getattr(self, f'layer{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+            if i != len(self.res_layers) and self.lateral:
+                # No fusion needed in the final stage
+                lateral_name = self.lateral_connections[i - 1]
+                conv_lateral = getattr(self, lateral_name)
+                conv_lateral.eval()
+                for param in conv_lateral.parameters():
+                    param.requires_grad = False
+
+    def init_weights(self, pretrained=None):
+        """Initiate the parameters either from existing checkpoint or from
+        scratch."""
+        if pretrained:
+            self.pretrained = pretrained
+
+        # Override the init_weights of i3d
+        super().init_weights()
+        for module_name in self.lateral_connections:
+            layer = getattr(self, module_name)
+            for m in layer.modules():
+                if isinstance(m, (nn.Conv3d, nn.Conv2d)):
+                    kaiming_init(m)
+
+
+pathway_cfg = {
+    'resnet3d': AMAGIPathway,
+    # TODO: BNInceptionPathway
+}
+
+
+def build_pathway(cfg, *args, **kwargs):
+    """Build pathway.
+
+    Args:
+        cfg (dict): cfg should contain:
+            - type (str): identifies the pathway type.
+
+    Returns:
+        nn.Module: Created pathway.
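+
+    Example:
+        >>> # A minimal sketch mirroring the defaults used by ``AMAGI``:
+        >>> pathway = build_pathway(
+        ...     dict(type='resnet3d', depth=50, pretrained=None))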
+    """
+    if not (isinstance(cfg, dict) and 'type' in cfg):
+        raise TypeError('cfg must be a dict containing the key "type"')
+    cfg_ = cfg.copy()
+
+    pathway_type = cfg_.pop('type')
+    if pathway_type not in pathway_cfg:
+        raise KeyError(f'Unrecognized pathway type {pathway_type}')
+
+    pathway_cls = pathway_cfg[pathway_type]
+    pathway = pathway_cls(*args, **kwargs, **cfg_)
+
+    return pathway
+
+
+@BACKBONES.register_module()
+class AMAGI(nn.Module):
+    """Slowfast backbone.
+
+    This module is proposed in `SlowFast Networks for Video Recognition
+    <https://arxiv.org/abs/1812.03982>`_
+
+    Args:
+        pretrained (str): The file path to a pretrained model.
+        resample_rate (int): A large temporal stride ``resample_rate``
+            on input frames. The actual resample rate is calculated by
+            multiplying the ``interval`` in ``SampleFrames`` in the
+            pipeline with ``resample_rate``, equivalent to the :math:`\\tau`
+            in the paper, i.e. it processes only one out of
+            ``resample_rate * interval`` frames. Default: 8.
+        speed_ratio (int): Speed ratio indicating the ratio between time
+            dimension of the fast and slow pathway, corresponding to the
+            :math:`\\alpha` in the paper. Default: 8.
+        channel_ratio (int): Reduce the channel number of fast pathway
+            by ``channel_ratio``, corresponding to :math:`\\beta` in the paper.
+            Default: 8.
+        kfold (int | str): Fold index forwarded to the segmentation model
+            loader. Default: 1.
+        slow_pathway (dict): Configuration of the slow branch. It should
+            contain the arguments needed to build the specific type of
+            pathway, plus:
+
+            - type (str): type of backbone the pathway is based on.
+            - lateral (bool): whether to build a lateral connection
+              for the pathway.
+
+            Default:
+
+            .. code-block:: Python
+
+                dict(type='resnet3d',
+                lateral=True, depth=50, pretrained=None,
+                conv1_kernel=(1, 7, 7), dilations=(1, 1, 1, 1),
+                conv1_stride_t=1, pool1_stride_t=1, inflate=(0, 0, 1, 1))
+
+        fast_pathway (dict): Configuration of the fast branch, similar to
+            ``slow_pathway``. Default:
+
+            .. code-block:: Python
+
+                dict(type='resnet3d',
+                lateral=False, depth=50, pretrained=None, base_channels=8,
+                conv1_kernel=(5, 7, 7), conv1_stride_t=1, pool1_stride_t=1)
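+
+    Example:
+        >>> # A minimal sketch; building ``AMAGI`` assumes the hardcoded
+        >>> # segmentation YAML (mmaction/models/segmentors/config/ocr.yml)
+        >>> # is present in the working tree.
+        >>> import torch
+        >>> backbone = AMAGI(pretrained=None)
+        >>> x = torch.randn(1, 3, 32, 224, 224)
+        >>> (x_slow, x_fast), seg_out = backbone(x)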
+    """
+
+    def __init__(self,
+                 pretrained,
+                 resample_rate=8,
+                 speed_ratio=8,
+                 channel_ratio=8,
+                 kfold=1,
+                 slow_pathway=dict(
+                     type='resnet3d',
+                     depth=50,
+                     pretrained=None,
+                     lateral=True,
+                     conv1_kernel=(1, 7, 7),
+                     dilations=(1, 1, 1, 1),
+                     conv1_stride_t=1,
+                     pool1_stride_t=1,
+                     inflate=(0, 0, 1, 1)),
+                 fast_pathway=dict(
+                     type='resnet3d',
+                     depth=50,
+                     pretrained=None,
+                     lateral=False,
+                     base_channels=8,
+                     conv1_kernel=(5, 7, 7),
+                     conv1_stride_t=1,
+                     pool1_stride_t=1)):
+        super().__init__()
+        self.pretrained = pretrained
+        self.resample_rate = resample_rate
+        self.speed_ratio = speed_ratio
+        self.channel_ratio = channel_ratio
+
+        if slow_pathway['lateral']:
+            slow_pathway['speed_ratio'] = speed_ratio
+            slow_pathway['channel_ratio'] = channel_ratio
+
+        self.slow_path = build_pathway(slow_pathway)
+        self.fast_path = build_pathway(fast_pathway)
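+        # Auxiliary HRNet-OCR segmentation model, run alongside the two
+        # pathways; its output is returned as the second element of
+        # ``forward``.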
+        self.seg_model = load_seg_model(kfold)
+
+    def init_weights(self, pretrained=None):
+        """Initiate the parameters either from existing checkpoint or from
+        scratch."""
+        if pretrained:
+            self.pretrained = pretrained
+
+        if isinstance(self.pretrained, str):
+            logger = get_root_logger()
+            msg = f'load model from: {self.pretrained}'
+            print_log(msg, logger=logger)
+            # Directly load 3D model.
+            load_checkpoint(self, self.pretrained, strict=True, logger=logger)
+        elif self.pretrained is None:
+            # Init two branch separately.
+            self.fast_path.init_weights()
+            self.slow_path.init_weights()
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        """Defines the computation performed at every call.
+
+        Args:
+            x (torch.Tensor): The input data.
+
+        Returns:
+            tuple[torch.Tensor]: The feature of the input samples extracted
+                by the backbone.
+        """
+        # Segmentation branch: run the segmentation model on the center
+        # frame only.
+        time_len = x.shape[2]
+        seg_input = x[:, :, time_len // 2]
+        seg_out = self.seg_model(seg_input)
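+
+        # Temporal subsampling: the slow path keeps 1 / resample_rate (tau)
+        # of the frames; the fast path keeps speed_ratio (alpha) times more,
+        # i.e. 1 / (resample_rate // speed_ratio).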
+        x_slow = nn.functional.interpolate(
+            x,
+            mode='nearest',
+            scale_factor=(1.0 / self.resample_rate, 1.0, 1.0))
+        x_slow = self.slow_path.conv1(x_slow)
+        x_slow = self.slow_path.maxpool(x_slow)
+
+        x_fast = nn.functional.interpolate(
+            x,
+            mode='nearest',
+            scale_factor=(1.0 / (self.resample_rate // self.speed_ratio), 1.0,
+                          1.0))
+        x_fast = self.fast_path.conv1(x_fast)
+        x_fast = self.fast_path.maxpool(x_fast)
+
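+        # Stem-level lateral fusion: compress fast features in time with the
+        # lateral conv, then concatenate onto the slow pathway channels.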
+        if self.slow_path.lateral:
+            x_fast_lateral = self.slow_path.conv1_lateral(x_fast)
+            x_slow = torch.cat((x_slow, x_fast_lateral), dim=1)
+
+        for i, layer_name in enumerate(self.slow_path.res_layers):
+            res_layer = getattr(self.slow_path, layer_name)
+            x_slow = res_layer(x_slow)
+            res_layer_fast = getattr(self.fast_path, layer_name)
+            x_fast = res_layer_fast(x_fast)
+            if (i != len(self.slow_path.res_layers) - 1
+                    and self.slow_path.lateral):
+                # No fusion needed in the final stage
+                lateral_name = self.slow_path.lateral_connections[i]
+                conv_lateral = getattr(self.slow_path, lateral_name)
+                x_fast_lateral = conv_lateral(x_fast)
+                x_slow = torch.cat((x_slow, x_fast_lateral), dim=1)
+
+        out = (x_slow, x_fast)
+
+        return out, seg_out
+
+
+if mmdet_imported:
+    MMDET_BACKBONES.register_module()(AMAGI)