--- /dev/null
+++ b/demo/visualize_heatmap_volume.ipynb
@@ -0,0 +1,403 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "speaking-algebra",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import cv2\n",
+    "import os.path as osp\n",
+    "import decord\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "import urllib\n",
+    "import moviepy.editor as mpy\n",
+    "import random as rd\n",
+    "from mmpose.apis import vis_pose_result\n",
+    "from mmpose.models import TopDown\n",
+    "from mmcv import load, dump\n",
+    "\n",
+    "# We assume the annotation is already prepared\n",
+    "gym_train_ann_file = '../data/skeleton/gym_train.pkl'\n",
+    "gym_val_ann_file = '../data/skeleton/gym_val.pkl'\n",
+    "ntu60_xsub_train_ann_file = '../data/skeleton/ntu60_xsub_train.pkl'\n",
+    "ntu60_xsub_val_ann_file = '../data/skeleton/ntu60_xsub_val.pkl'"
+   ]
+  },
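+  {
+   "cell_type": "markdown",
+   "id": "anno-format-note",
+   "metadata": {},
+   "source": [
+    "A quick look at one annotation clarifies the format the helpers below rely on: each item is a dict with at least `frame_dir`, `total_frames`, `label`, `keypoint` (N x T x K x 2) and `keypoint_score` (N x T x K). The check below is optional and assumes the pkl files listed above are in place."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "anno-format-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sanity check: inspect the keys and array shapes of one annotation\n",
+    "sample = load(gym_train_ann_file)[0]\n",
+    "print(sorted(sample))\n",
+    "print(sample['keypoint'].shape, sample['keypoint_score'].shape)"
+   ]
+  },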
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "alive-consolidation",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "FONTFACE = cv2.FONT_HERSHEY_DUPLEX\n",
+    "FONTSCALE = 0.6\n",
+    "FONTCOLOR = (255, 255, 255)\n",
+    "BGBLUE = (0, 119, 182)\n",
+    "THICKNESS = 1\n",
+    "LINETYPE = 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "ranging-conjunction",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def add_label(frame, label, BGCOLOR=BGBLUE):\n",
+    "    threshold = 30\n",
+    "    def split_label(label):\n",
+    "        label = label.split()\n",
+    "        lines, cline = [], ''\n",
+    "        for word in label:\n",
+    "            if len(cline) + len(word) < threshold:\n",
+    "                cline = cline + ' ' + word\n",
+    "            else:\n",
+    "                lines.append(cline)\n",
+    "                cline = word\n",
+    "        if cline != '':\n",
+    "            lines += [cline]\n",
+    "        return lines\n",
+    "    \n",
+    "    if len(label) > 30:\n",
+    "        label = split_label(label)\n",
+    "    else:\n",
+    "        label = [label]\n",
+    "    label = ['Action: '] + label\n",
+    "    \n",
+    "    sizes = []\n",
+    "    for line in label:\n",
+    "        sizes.append(cv2.getTextSize(line, FONTFACE, FONTSCALE, THICKNESS)[0])\n",
+    "    box_width = max([x[0] for x in sizes]) + 10\n",
+    "    text_height = sizes[0][1]\n",
+    "    box_height = len(sizes) * (text_height + 6)\n",
+    "    \n",
+    "    cv2.rectangle(frame, (0, 0), (box_width, box_height), BGCOLOR, -1)\n",
+    "    for i, line in enumerate(label):\n",
+    "        location = (5, (text_height + 6) * i + text_height + 3)\n",
+    "        cv2.putText(frame, line, location, FONTFACE, FONTSCALE, FONTCOLOR, THICKNESS, LINETYPE)\n",
+    "    return frame\n",
+    "    \n",
+    "\n",
+    "def vis_skeleton(vid_path, anno, category_name=None, ratio=0.5):\n",
+    "    vid = decord.VideoReader(vid_path)\n",
+    "    frames = [x.asnumpy() for x in vid]\n",
+    "    \n",
+    "    h, w, _ = frames[0].shape\n",
+    "    new_shape = (int(w * ratio), int(h * ratio))\n",
+    "    frames = [cv2.resize(f, new_shape) for f in frames]\n",
+    "    \n",
+    "    assert len(frames) == anno['total_frames']\n",
+    "    # The shape is N x T x K x 3\n",
+    "    kps = np.concatenate([anno['keypoint'], anno['keypoint_score'][..., None]], axis=-1)\n",
+    "    kps[..., :2] *= ratio\n",
+    "    # Convert to T x N x K x 3\n",
+    "    kps = kps.transpose([1, 0, 2, 3])\n",
+    "    vis_frames = []\n",
+    "\n",
+    "    # we need an instance of TopDown model, so build a minimal one\n",
+    "    model = TopDown(backbone=dict(type='ShuffleNetV1'))\n",
+    "\n",
+    "    for f, kp in zip(frames, kps):\n",
+    "        bbox = np.zeros([0, 4], dtype=np.float32)\n",
+    "        result = [dict(bbox=bbox, keypoints=k) for k in kp]\n",
+    "        vis_frame = vis_pose_result(model, f, result)\n",
+    "        \n",
+    "        if category_name is not None:\n",
+    "            vis_frame = add_label(vis_frame, category_name)\n",
+    "        \n",
+    "        vis_frames.append(vis_frame)\n",
+    "    return vis_frames"
+   ]
+  },
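+  {
+   "cell_type": "markdown",
+   "id": "label-demo-note",
+   "metadata": {},
+   "source": [
+    "A minimal sanity check for `add_label` (the blank frame and the label string below are made up): labels longer than 30 characters should be wrapped onto multiple lines inside the colored box."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "label-demo",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "demo_frame = np.full((120, 320, 3), 255, dtype=np.uint8)\n",
+    "demo_frame = add_label(demo_frame, 'standing up (from sitting position)')\n",
+    "plt.imshow(demo_frame)\n",
+    "plt.axis('off')\n",
+    "plt.show()"
+   ]
+  },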
+  {
+   "cell_type": "code",
+   "execution_count": 55,
+   "id": "applied-humanity",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "keypoint_pipeline = [\n",
+    "    dict(type='PoseDecode'),\n",
+    "    dict(type='PoseCompact', hw_ratio=1., allow_imgpad=True),\n",
+    "    dict(type='Resize', scale=(-1, 64)),\n",
+    "    dict(type='CenterCrop', crop_size=64),\n",
+    "    dict(type='GeneratePoseTarget', sigma=0.6, use_score=True, with_kp=True, with_limb=False)\n",
+    "]\n",
+    "\n",
+    "limb_pipeline = [\n",
+    "    dict(type='PoseDecode'),\n",
+    "    dict(type='PoseCompact', hw_ratio=1., allow_imgpad=True),\n",
+    "    dict(type='Resize', scale=(-1, 64)),\n",
+    "    dict(type='CenterCrop', crop_size=64),\n",
+    "    dict(type='GeneratePoseTarget', sigma=0.6, use_score=True, with_kp=False, with_limb=True)\n",
+    "]\n",
+    "\n",
+    "from mmaction.datasets.pipelines import Compose\n",
+    "def get_pseudo_heatmap(anno, flag='keypoint'):\n",
+    "    assert flag in ['keypoint', 'limb']\n",
+    "    pipeline = Compose(keypoint_pipeline if flag == 'keypoint' else limb_pipeline)\n",
+    "    return pipeline(anno)['imgs']\n",
+    "\n",
+    "def vis_heatmaps(heatmaps, channel=-1, ratio=8):\n",
+    "    # if channel is -1, draw all keypoints / limbs on the same map\n",
+    "    import matplotlib.cm as cm\n",
+    "    h, w, _ = heatmaps[0].shape\n",
+    "    newh, neww = int(h * ratio), int(w * ratio)\n",
+    "    \n",
+    "    if channel == -1:\n",
+    "        heatmaps = [np.max(x, axis=-1) for x in heatmaps]\n",
+    "    cmap = cm.viridis\n",
+    "    heatmaps = [(cmap(x)[..., :3] * 255).astype(np.uint8) for x in heatmaps]\n",
+    "    heatmaps = [cv2.resize(x, (neww, newh)) for x in heatmaps]\n",
+    "    return heatmaps"
+   ]
+  },
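+  {
+   "cell_type": "markdown",
+   "id": "pipeline-note",
+   "metadata": {},
+   "source": [
+    "The two pipelines differ only in the `GeneratePoseTarget` step: `with_kp=True` renders one Gaussian map per keypoint, `with_limb=True` renders one map per limb, and `use_score=True` weights each map by the keypoint confidence. The cell below is an optional smoke test on a synthetic annotation (every key and shape is made up to mirror the real pkl entries), so you can check the output shape without downloading any video."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "pipeline-smoke-test",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fake_anno = dict(\n",
+    "    frame_dir='fake',\n",
+    "    label=-1,\n",
+    "    img_shape=(100, 100),\n",
+    "    original_shape=(100, 100),\n",
+    "    start_index=0,\n",
+    "    modality='Pose',\n",
+    "    total_frames=8,\n",
+    "    # one person, 8 frames, 17 COCO-style keypoints\n",
+    "    keypoint=np.random.rand(1, 8, 17, 2).astype(np.float32) * 100,\n",
+    "    keypoint_score=np.ones((1, 8, 17), dtype=np.float32))\n",
+    "heatmap = get_pseudo_heatmap(fake_anno)\n",
+    "# expect 8 frames, each 64 x 64 x 17\n",
+    "print(len(heatmap), heatmap[0].shape)"
+   ]
+  },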
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "automatic-commons",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load GYM annotations\n",
+    "lines = list(urllib.request.urlopen('https://sdolivia.github.io/FineGym/resources/dataset/gym99_categories.txt'))\n",
+    "gym_categories = [x.decode().strip().split('; ')[-1] for x in lines]\n",
+    "gym_annos = load(gym_train_ann_file) + load(gym_val_ann_file)"
+   ]
+  },
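+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "gym-count-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# GYM99 should yield 99 category names; the annotation count depends on the prepared pkls\n",
+    "print(len(gym_categories), len(gym_annos))"
+   ]
+  },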
+  {
+   "cell_type": "code",
+   "execution_count": 74,
+   "id": "numerous-bristol",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "--2021-04-25 22:18:53--  https://download.openmmlab.com/mmaction/posec3d/gym_samples.tar\n",
+      "Resolving download.openmmlab.com (download.openmmlab.com)... 124.160.145.22\n",
+      "Connecting to download.openmmlab.com (download.openmmlab.com)|124.160.145.22|:443... connected.\n",
+      "HTTP request sent, awaiting response... 200 OK\n",
+      "Length: 36300800 (35M) [application/x-tar]\n",
+      "Saving to: ‘gym_samples.tar’\n",
+      "\n",
+      "100%[======================================>] 36,300,800  11.5MB/s   in 3.0s   \n",
+      "\n",
+      "2021-04-25 22:18:58 (11.5 MB/s) - ‘gym_samples.tar’ saved [36300800/36300800]\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "# download sample videos of GYM\n",
+    "!wget https://download.openmmlab.com/mmaction/posec3d/gym_samples.tar\n",
+    "!tar -xf gym_samples.tar\n",
+    "!rm gym_samples.tar"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 76,
+   "id": "ranging-harrison",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gym_root = 'gym_samples/'\n",
+    "gym_vids = os.listdir(gym_root)\n",
+    "# visualize pose of which video? index in 0 - 50.\n",
+    "idx = 1\n",
+    "vid = gym_vids[idx]\n",
+    "\n",
+    "frame_dir = vid.split('.')[0]\n",
+    "vid_path = osp.join(gym_root, vid)\n",
+    "anno = [x for x in gym_annos if x['frame_dir'] == frame_dir][0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 86,
+   "id": "fitting-courage",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Visualize Skeleton\n",
+    "vis_frames = vis_skeleton(vid_path, anno, gym_categories[anno['label']])\n",
+    "vid = mpy.ImageSequenceClip(vis_frames, fps=24)\n",
+    "vid.ipython_display()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 87,
+   "id": "orange-logging",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "keypoint_heatmap = get_pseudo_heatmap(anno)\n",
+    "keypoint_mapvis = vis_heatmaps(keypoint_heatmap)\n",
+    "keypoint_mapvis = [add_label(f, gym_categories[anno['label']]) for f in keypoint_mapvis]\n",
+    "vid = mpy.ImageSequenceClip(keypoint_mapvis, fps=24)\n",
+    "vid.ipython_display()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 88,
+   "id": "residential-conjunction",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "limb_heatmap = get_pseudo_heatmap(anno, 'limb')\n",
+    "limb_mapvis = vis_heatmaps(limb_heatmap)\n",
+    "limb_mapvis = [add_label(f, gym_categories[anno['label']]) for f in limb_mapvis]\n",
+    "vid = mpy.ImageSequenceClip(limb_mapvis, fps=24)\n",
+    "vid.ipython_display()"
+   ]
+  },
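+  {
+   "cell_type": "markdown",
+   "id": "export-note",
+   "metadata": {},
+   "source": [
+    "Any of these clips can also be written to disk instead of displayed inline; the filename below is just an example."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "export-demo",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vid.write_videofile('gym_limb_heatmap.mp4')"
+   ]
+  },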
+  {
+   "cell_type": "code",
+   "execution_count": 66,
+   "id": "coupled-stranger",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# The name list of \n",
+    "ntu_categories = ['drink water', 'eat meal/snack', 'brushing teeth', 'brushing hair', 'drop', 'pickup', \n",
+    "                  'throw', 'sitting down', 'standing up (from sitting position)', 'clapping', 'reading', \n",
+    "                  'writing', 'tear up paper', 'wear jacket', 'take off jacket', 'wear a shoe', \n",
+    "                  'take off a shoe', 'wear on glasses', 'take off glasses', 'put on a hat/cap', \n",
+    "                  'take off a hat/cap', 'cheer up', 'hand waving', 'kicking something', \n",
+    "                  'reach into pocket', 'hopping (one foot jumping)', 'jump up', \n",
+    "                  'make a phone call/answer phone', 'playing with phone/tablet', 'typing on a keyboard', \n",
+    "                  'pointing to something with finger', 'taking a selfie', 'check time (from watch)', \n",
+    "                  'rub two hands together', 'nod head/bow', 'shake head', 'wipe face', 'salute', \n",
+    "                  'put the palms together', 'cross hands in front (say stop)', 'sneeze/cough', \n",
+    "                  'staggering', 'falling', 'touch head (headache)', 'touch chest (stomachache/heart pain)', \n",
+    "                  'touch back (backache)', 'touch neck (neckache)', 'nausea or vomiting condition', \n",
+    "                  'use a fan (with hand or paper)/feeling warm', 'punching/slapping other person', \n",
+    "                  'kicking other person', 'pushing other person', 'pat on back of other person', \n",
+    "                  'point finger at the other person', 'hugging other person', \n",
+    "                  'giving something to other person', \"touch other person's pocket\", 'handshaking', \n",
+    "                  'walking towards each other', 'walking apart from each other']\n",
+    "ntu_annos = load(ntu60_xsub_train_ann_file) + load(ntu60_xsub_val_ann_file)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 80,
+   "id": "critical-review",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ntu_root = 'ntu_samples/'\n",
+    "ntu_vids = os.listdir(ntu_root)\n",
+    "# visualize pose of which video? index in 0 - 50.\n",
+    "idx = 20\n",
+    "vid = ntu_vids[idx]\n",
+    "\n",
+    "frame_dir = vid.split('.')[0]\n",
+    "vid_path = osp.join(ntu_root, vid)\n",
+    "anno = [x for x in ntu_annos if x['frame_dir'] == frame_dir.split('_')[0]][0]\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 81,
+   "id": "seasonal-palmer",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "--2021-04-25 22:21:16--  https://download.openmmlab.com/mmaction/posec3d/ntu_samples.tar\n",
+      "Resolving download.openmmlab.com (download.openmmlab.com)... 124.160.145.22\n",
+      "Connecting to download.openmmlab.com (download.openmmlab.com)|124.160.145.22|:443... connected.\n",
+      "HTTP request sent, awaiting response... 200 OK\n",
+      "Length: 121753600 (116M) [application/x-tar]\n",
+      "Saving to: ‘ntu_samples.tar’\n",
+      "\n",
+      "100%[======================================>] 121,753,600 14.4MB/s   in 9.2s   \n",
+      "\n",
+      "2021-04-25 22:21:26 (12.6 MB/s) - ‘ntu_samples.tar’ saved [121753600/121753600]\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "# download sample videos of NTU-60\n",
+    "!wget https://download.openmmlab.com/mmaction/posec3d/ntu_samples.tar\n",
+    "!tar -xf ntu_samples.tar\n",
+    "!rm ntu_samples.tar"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 89,
+   "id": "accompanied-invitation",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis_frames = vis_skeleton(vid_path, anno, ntu_categories[anno['label']])\n",
+    "vid = mpy.ImageSequenceClip(vis_frames, fps=24)\n",
+    "vid.ipython_display()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 90,
+   "id": "respiratory-conclusion",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "keypoint_heatmap = get_pseudo_heatmap(anno)\n",
+    "keypoint_mapvis = vis_heatmaps(keypoint_heatmap)\n",
+    "keypoint_mapvis = [add_label(f, gym_categories[anno['label']]) for f in keypoint_mapvis]\n",
+    "vid = mpy.ImageSequenceClip(keypoint_mapvis, fps=24)\n",
+    "vid.ipython_display()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 91,
+   "id": "thirty-vancouver",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "limb_heatmap = get_pseudo_heatmap(anno, 'limb')\n",
+    "limb_mapvis = vis_heatmaps(limb_heatmap)\n",
+    "limb_mapvis = [add_label(f, gym_categories[anno['label']]) for f in limb_mapvis]\n",
+    "vid = mpy.ImageSequenceClip(limb_mapvis, fps=24)\n",
+    "vid.ipython_display()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}