4167 lines (4166 with data), 259.6 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import TensorDataset\n",
"import torch.optim as optim\n",
"from torch.optim import lr_scheduler\n",
"import numpy as np\n",
"import torchvision\n",
"import torch.nn.functional as F\n",
"from torch.utils.data.sampler import SubsetRandomSampler\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets, models, transforms\n",
"from torchvision.transforms import Resize, ToTensor, Normalize\n",
"import matplotlib.pyplot as plt\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"\n",
"from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_auc_score, \\\n",
" average_precision_score\n",
"from sklearn.model_selection import train_test_split\n",
"import time\n",
"import os\n",
"from pathlib import Path\n",
"from skimage import io\n",
"import copy\n",
"from torch import optim, cuda\n",
"import pandas as pd\n",
"import glob\n",
"from collections import Counter\n",
"# Useful for examining network\n",
"from functools import reduce\n",
"from operator import __add__\n",
"# from torchsummary import summary\n",
"import seaborn as sns\n",
"import warnings\n",
"# warnings.filterwarnings('ignore', category=FutureWarning)\n",
"from PIL import Image\n",
"from timeit import default_timer as timer\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Useful for examining network\n",
"from functools import reduce\n",
"from operator import __add__\n",
"from torchsummary import summary\n",
"\n",
"# from IPython.core.interactiveshell import InteractiveShell\n",
"import seaborn as sns\n",
"\n",
"import warnings\n",
"# warnings.filterwarnings('ignore', category=FutureWarning)\n",
"\n",
"# Image manipulations\n",
"from PIL import Image\n",
"\n",
"# Timing utility\n",
"from timeit import default_timer as timer\n",
"\n",
"# Visualizations\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"class Multi_2D_CNN_block(nn.Module):\n",
" def __init__(self, in_channels, num_kernel):\n",
" super(Multi_2D_CNN_block, self).__init__()\n",
" conv_block = BasicConv2d\n",
" self.a = conv_block(in_channels, int(num_kernel / 3), kernel_size=(1, 1))\n",
"\n",
" self.b = nn.Sequential(\n",
" conv_block(in_channels, int(num_kernel / 2), kernel_size=(1, 1)),\n",
" conv_block(int(num_kernel / 2), int(num_kernel), kernel_size=(3, 3))\n",
" )\n",
"\n",
" self.c = nn.Sequential(\n",
" conv_block(in_channels, int(num_kernel / 3), kernel_size=(1, 1)),\n",
" conv_block(int(num_kernel / 3), int(num_kernel / 2), kernel_size=(3, 3)),\n",
" conv_block(int(num_kernel / 2), int(num_kernel), kernel_size=(3, 3))\n",
" )\n",
" self.out_channels = int(num_kernel / 3) + int(num_kernel) + int(num_kernel)\n",
" # I get out_channels is total number of out_channels for a/b/c\n",
" self.bn = nn.BatchNorm2d(self.out_channels)\n",
"\n",
" def get_out_channels(self):\n",
" return self.out_channels\n",
"\n",
" def forward(self, x):\n",
" branch1 = self.a(x)\n",
" branch2 = self.b(x)\n",
" branch3 = self.c(x)\n",
" output = [branch1, branch2, branch3]\n",
" return self.bn(torch.cat(output,\n",
" 1)) # BatchNorm across the concatenation of output channels from final layer of Branch 1/2/3\n",
" # ,1 refers to the channel dimension\n",
"\n",
"\n",
"class BasicConv2d(nn.Module):\n",
" def __init__(self, in_channels, out_channels, kernel_size, **kwargs):\n",
" super(BasicConv2d, self).__init__()\n",
" conv_padding = reduce(__add__, [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in kernel_size[::-1]])\n",
" self.pad = nn.ZeroPad2d(conv_padding)\n",
" # ZeroPad2d Output: :math:`(N, C, H_{out}, W_{out})` H_{out} is H_{in} with the padding to be added to either side of height\n",
" # ZeroPad2d(2) would add 2 to all 4 sides, ZeroPad2d((1,1,2,0)) would add 1 left, 1 right, 2 above, 0 below\n",
" # n_output_features = floor((n_input_features + 2(paddingsize) - convkernel_size) / stride_size) + 1\n",
" # above creates same padding\n",
" self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=False, **kwargs)\n",
" self.bn = nn.BatchNorm2d(out_channels)\n",
"\n",
" def forward(self, x):\n",
" x = self.pad(x)\n",
" x = self.conv(x)\n",
" x = F.relu(x, inplace=True)\n",
" x = self.bn(x)\n",
" return x\n",
"\n",
"\n",
"class MyModel(nn.Module):\n",
"\n",
" def __init__(self, initial_kernel_num=64):\n",
" super(MyModel, self).__init__()\n",
"\n",
" multi_2d_cnn = Multi_2D_CNN_block\n",
" conv_block = BasicConv2d\n",
"\n",
" self.conv_1 = conv_block(1, 64, kernel_size=(7, 3), stride=(2, 1))\n",
"\n",
" self.multi_2d_cnn_1a = nn.Sequential(\n",
" multi_2d_cnn(in_channels=64, num_kernel=initial_kernel_num),\n",
" multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num),\n",
" nn.MaxPool2d(kernel_size=(3, 1))\n",
" )\n",
"\n",
" self.multi_2d_cnn_1b = nn.Sequential(\n",
" multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num * 1.5),\n",
" multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 1.5),\n",
" nn.MaxPool2d(kernel_size=(3, 1))\n",
" )\n",
" self.multi_2d_cnn_1c = nn.Sequential(\n",
" multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 2),\n",
" multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 2),\n",
" nn.MaxPool2d(kernel_size=(2, 1))\n",
" )\n",
"\n",
" self.multi_2d_cnn_2a = nn.Sequential(\n",
" multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 3),\n",
" multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 3),\n",
" multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 4),\n",
" nn.MaxPool2d(kernel_size=(2, 1))\n",
" )\n",
" self.multi_2d_cnn_2b = nn.Sequential(\n",
" multi_2d_cnn(in_channels=597, num_kernel=initial_kernel_num * 5),\n",
" multi_2d_cnn(in_channels=746, num_kernel=initial_kernel_num * 6),\n",
" multi_2d_cnn(in_channels=896, num_kernel=initial_kernel_num * 7),\n",
" nn.MaxPool2d(kernel_size=(2, 1))\n",
" )\n",
" self.multi_2d_cnn_2c = nn.Sequential(\n",
" multi_2d_cnn(in_channels=1045, num_kernel=initial_kernel_num * 8),\n",
" multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n",
" multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n",
" nn.MaxPool2d(kernel_size=(2, 1))\n",
" )\n",
" self.multi_2d_cnn_2d = nn.Sequential(\n",
" multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 12),\n",
" multi_2d_cnn(in_channels=1792, num_kernel=initial_kernel_num * 14),\n",
" multi_2d_cnn(in_channels=2090, num_kernel=initial_kernel_num * 16),\n",
" )\n",
" self.output = nn.Sequential(\n",
" nn.AdaptiveAvgPool2d((1, 1)),\n",
" nn.Flatten(),\n",
" nn.Dropout(0.5),\n",
" nn.Linear(2389, 1),\n",
" # nn.Sigmoid()\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.conv_1(x)\n",
" # N x 1250 x 12 x 64 tensor\n",
" x = self.multi_2d_cnn_1a(x)\n",
" # N x 416 x 12 x 149 tensor\n",
" x = self.multi_2d_cnn_1b(x)\n",
" # N x 138 x 12 x 224 tensor\n",
" x = self.multi_2d_cnn_1c(x)\n",
" # N x 69 x 12 x 298\n",
" x = self.multi_2d_cnn_2a(x)\n",
"\n",
" x = self.multi_2d_cnn_2b(x)\n",
"\n",
" x = self.multi_2d_cnn_2c(x)\n",
"\n",
" x = self.multi_2d_cnn_2d(x)\n",
"\n",
" x = self.output(x)\n",
"\n",
" return x\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on: cpu\n"
]
}
],
"source": [
"\n",
"model = MyModel()\n",
"\n",
"# Device check\n",
"device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f'Train on: {device}')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" ZeroPad2d-1 [-1, 1, 2506, 14] 0\n",
" Conv2d-2 [-1, 64, 1250, 12] 1,344\n",
" BatchNorm2d-3 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-4 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-5 [-1, 64, 1250, 12] 0\n",
" Conv2d-6 [-1, 21, 1250, 12] 1,344\n",
" BatchNorm2d-7 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-8 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-9 [-1, 64, 1250, 12] 0\n",
" Conv2d-10 [-1, 32, 1250, 12] 2,048\n",
" BatchNorm2d-11 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-12 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-13 [-1, 32, 1252, 14] 0\n",
" Conv2d-14 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-15 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-16 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-17 [-1, 64, 1250, 12] 0\n",
" Conv2d-18 [-1, 21, 1250, 12] 1,344\n",
" BatchNorm2d-19 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-20 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-21 [-1, 21, 1252, 14] 0\n",
" Conv2d-22 [-1, 32, 1250, 12] 6,048\n",
" BatchNorm2d-23 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-24 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-25 [-1, 32, 1252, 14] 0\n",
" Conv2d-26 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-27 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-28 [-1, 64, 1250, 12] 0\n",
" BatchNorm2d-29 [-1, 149, 1250, 12] 298\n",
"Multi_2D_CNN_block-30 [-1, 149, 1250, 12] 0\n",
" ZeroPad2d-31 [-1, 149, 1250, 12] 0\n",
" Conv2d-32 [-1, 21, 1250, 12] 3,129\n",
" BatchNorm2d-33 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-34 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-35 [-1, 149, 1250, 12] 0\n",
" Conv2d-36 [-1, 32, 1250, 12] 4,768\n",
" BatchNorm2d-37 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-38 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-39 [-1, 32, 1252, 14] 0\n",
" Conv2d-40 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-41 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-42 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-43 [-1, 149, 1250, 12] 0\n",
" Conv2d-44 [-1, 21, 1250, 12] 3,129\n",
" BatchNorm2d-45 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-46 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-47 [-1, 21, 1252, 14] 0\n",
" Conv2d-48 [-1, 32, 1250, 12] 6,048\n",
" BatchNorm2d-49 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-50 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-51 [-1, 32, 1252, 14] 0\n",
" Conv2d-52 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-53 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-54 [-1, 64, 1250, 12] 0\n",
" BatchNorm2d-55 [-1, 149, 1250, 12] 298\n",
"Multi_2D_CNN_block-56 [-1, 149, 1250, 12] 0\n",
" MaxPool2d-57 [-1, 149, 416, 12] 0\n",
" ZeroPad2d-58 [-1, 149, 416, 12] 0\n",
" Conv2d-59 [-1, 32, 416, 12] 4,768\n",
" BatchNorm2d-60 [-1, 32, 416, 12] 64\n",
" BasicConv2d-61 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-62 [-1, 149, 416, 12] 0\n",
" Conv2d-63 [-1, 48, 416, 12] 7,152\n",
" BatchNorm2d-64 [-1, 48, 416, 12] 96\n",
" BasicConv2d-65 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-66 [-1, 48, 418, 14] 0\n",
" Conv2d-67 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-68 [-1, 96, 416, 12] 192\n",
" BasicConv2d-69 [-1, 96, 416, 12] 0\n",
" ZeroPad2d-70 [-1, 149, 416, 12] 0\n",
" Conv2d-71 [-1, 32, 416, 12] 4,768\n",
" BatchNorm2d-72 [-1, 32, 416, 12] 64\n",
" BasicConv2d-73 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-74 [-1, 32, 418, 14] 0\n",
" Conv2d-75 [-1, 48, 416, 12] 13,824\n",
" BatchNorm2d-76 [-1, 48, 416, 12] 96\n",
" BasicConv2d-77 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-78 [-1, 48, 418, 14] 0\n",
" Conv2d-79 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-80 [-1, 96, 416, 12] 192\n",
" BasicConv2d-81 [-1, 96, 416, 12] 0\n",
" BatchNorm2d-82 [-1, 224, 416, 12] 448\n",
"Multi_2D_CNN_block-83 [-1, 224, 416, 12] 0\n",
" ZeroPad2d-84 [-1, 224, 416, 12] 0\n",
" Conv2d-85 [-1, 32, 416, 12] 7,168\n",
" BatchNorm2d-86 [-1, 32, 416, 12] 64\n",
" BasicConv2d-87 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-88 [-1, 224, 416, 12] 0\n",
" Conv2d-89 [-1, 48, 416, 12] 10,752\n",
" BatchNorm2d-90 [-1, 48, 416, 12] 96\n",
" BasicConv2d-91 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-92 [-1, 48, 418, 14] 0\n",
" Conv2d-93 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-94 [-1, 96, 416, 12] 192\n",
" BasicConv2d-95 [-1, 96, 416, 12] 0\n",
" ZeroPad2d-96 [-1, 224, 416, 12] 0\n",
" Conv2d-97 [-1, 32, 416, 12] 7,168\n",
" BatchNorm2d-98 [-1, 32, 416, 12] 64\n",
" BasicConv2d-99 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-100 [-1, 32, 418, 14] 0\n",
" Conv2d-101 [-1, 48, 416, 12] 13,824\n",
" BatchNorm2d-102 [-1, 48, 416, 12] 96\n",
" BasicConv2d-103 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-104 [-1, 48, 418, 14] 0\n",
" Conv2d-105 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-106 [-1, 96, 416, 12] 192\n",
" BasicConv2d-107 [-1, 96, 416, 12] 0\n",
" BatchNorm2d-108 [-1, 224, 416, 12] 448\n",
"Multi_2D_CNN_block-109 [-1, 224, 416, 12] 0\n",
" MaxPool2d-110 [-1, 224, 138, 12] 0\n",
" ZeroPad2d-111 [-1, 224, 138, 12] 0\n",
" Conv2d-112 [-1, 42, 138, 12] 9,408\n",
" BatchNorm2d-113 [-1, 42, 138, 12] 84\n",
" BasicConv2d-114 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-115 [-1, 224, 138, 12] 0\n",
" Conv2d-116 [-1, 64, 138, 12] 14,336\n",
" BatchNorm2d-117 [-1, 64, 138, 12] 128\n",
" BasicConv2d-118 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-119 [-1, 64, 140, 14] 0\n",
" Conv2d-120 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-121 [-1, 128, 138, 12] 256\n",
" BasicConv2d-122 [-1, 128, 138, 12] 0\n",
" ZeroPad2d-123 [-1, 224, 138, 12] 0\n",
" Conv2d-124 [-1, 42, 138, 12] 9,408\n",
" BatchNorm2d-125 [-1, 42, 138, 12] 84\n",
" BasicConv2d-126 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-127 [-1, 42, 140, 14] 0\n",
" Conv2d-128 [-1, 64, 138, 12] 24,192\n",
" BatchNorm2d-129 [-1, 64, 138, 12] 128\n",
" BasicConv2d-130 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-131 [-1, 64, 140, 14] 0\n",
" Conv2d-132 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-133 [-1, 128, 138, 12] 256\n",
" BasicConv2d-134 [-1, 128, 138, 12] 0\n",
" BatchNorm2d-135 [-1, 298, 138, 12] 596\n",
"Multi_2D_CNN_block-136 [-1, 298, 138, 12] 0\n",
" ZeroPad2d-137 [-1, 298, 138, 12] 0\n",
" Conv2d-138 [-1, 42, 138, 12] 12,516\n",
" BatchNorm2d-139 [-1, 42, 138, 12] 84\n",
" BasicConv2d-140 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-141 [-1, 298, 138, 12] 0\n",
" Conv2d-142 [-1, 64, 138, 12] 19,072\n",
" BatchNorm2d-143 [-1, 64, 138, 12] 128\n",
" BasicConv2d-144 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-145 [-1, 64, 140, 14] 0\n",
" Conv2d-146 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-147 [-1, 128, 138, 12] 256\n",
" BasicConv2d-148 [-1, 128, 138, 12] 0\n",
" ZeroPad2d-149 [-1, 298, 138, 12] 0\n",
" Conv2d-150 [-1, 42, 138, 12] 12,516\n",
" BatchNorm2d-151 [-1, 42, 138, 12] 84\n",
" BasicConv2d-152 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-153 [-1, 42, 140, 14] 0\n",
" Conv2d-154 [-1, 64, 138, 12] 24,192\n",
" BatchNorm2d-155 [-1, 64, 138, 12] 128\n",
" BasicConv2d-156 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-157 [-1, 64, 140, 14] 0\n",
" Conv2d-158 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-159 [-1, 128, 138, 12] 256\n",
" BasicConv2d-160 [-1, 128, 138, 12] 0\n",
" BatchNorm2d-161 [-1, 298, 138, 12] 596\n",
"Multi_2D_CNN_block-162 [-1, 298, 138, 12] 0\n",
" MaxPool2d-163 [-1, 298, 69, 12] 0\n",
" ZeroPad2d-164 [-1, 298, 69, 12] 0\n",
" Conv2d-165 [-1, 64, 69, 12] 19,072\n",
" BatchNorm2d-166 [-1, 64, 69, 12] 128\n",
" BasicConv2d-167 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-168 [-1, 298, 69, 12] 0\n",
" Conv2d-169 [-1, 96, 69, 12] 28,608\n",
" BatchNorm2d-170 [-1, 96, 69, 12] 192\n",
" BasicConv2d-171 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-172 [-1, 96, 71, 14] 0\n",
" Conv2d-173 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-174 [-1, 192, 69, 12] 384\n",
" BasicConv2d-175 [-1, 192, 69, 12] 0\n",
" ZeroPad2d-176 [-1, 298, 69, 12] 0\n",
" Conv2d-177 [-1, 64, 69, 12] 19,072\n",
" BatchNorm2d-178 [-1, 64, 69, 12] 128\n",
" BasicConv2d-179 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-180 [-1, 64, 71, 14] 0\n",
" Conv2d-181 [-1, 96, 69, 12] 55,296\n",
" BatchNorm2d-182 [-1, 96, 69, 12] 192\n",
" BasicConv2d-183 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-184 [-1, 96, 71, 14] 0\n",
" Conv2d-185 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-186 [-1, 192, 69, 12] 384\n",
" BasicConv2d-187 [-1, 192, 69, 12] 0\n",
" BatchNorm2d-188 [-1, 448, 69, 12] 896\n",
"Multi_2D_CNN_block-189 [-1, 448, 69, 12] 0\n",
" ZeroPad2d-190 [-1, 448, 69, 12] 0\n",
" Conv2d-191 [-1, 64, 69, 12] 28,672\n",
" BatchNorm2d-192 [-1, 64, 69, 12] 128\n",
" BasicConv2d-193 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-194 [-1, 448, 69, 12] 0\n",
" Conv2d-195 [-1, 96, 69, 12] 43,008\n",
" BatchNorm2d-196 [-1, 96, 69, 12] 192\n",
" BasicConv2d-197 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-198 [-1, 96, 71, 14] 0\n",
" Conv2d-199 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-200 [-1, 192, 69, 12] 384\n",
" BasicConv2d-201 [-1, 192, 69, 12] 0\n",
" ZeroPad2d-202 [-1, 448, 69, 12] 0\n",
" Conv2d-203 [-1, 64, 69, 12] 28,672\n",
" BatchNorm2d-204 [-1, 64, 69, 12] 128\n",
" BasicConv2d-205 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-206 [-1, 64, 71, 14] 0\n",
" Conv2d-207 [-1, 96, 69, 12] 55,296\n",
" BatchNorm2d-208 [-1, 96, 69, 12] 192\n",
" BasicConv2d-209 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-210 [-1, 96, 71, 14] 0\n",
" Conv2d-211 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-212 [-1, 192, 69, 12] 384\n",
" BasicConv2d-213 [-1, 192, 69, 12] 0\n",
" BatchNorm2d-214 [-1, 448, 69, 12] 896\n",
"Multi_2D_CNN_block-215 [-1, 448, 69, 12] 0\n",
" ZeroPad2d-216 [-1, 448, 69, 12] 0\n",
" Conv2d-217 [-1, 85, 69, 12] 38,080\n",
" BatchNorm2d-218 [-1, 85, 69, 12] 170\n",
" BasicConv2d-219 [-1, 85, 69, 12] 0\n",
" ZeroPad2d-220 [-1, 448, 69, 12] 0\n",
" Conv2d-221 [-1, 128, 69, 12] 57,344\n",
" BatchNorm2d-222 [-1, 128, 69, 12] 256\n",
" BasicConv2d-223 [-1, 128, 69, 12] 0\n",
" ZeroPad2d-224 [-1, 128, 71, 14] 0\n",
" Conv2d-225 [-1, 256, 69, 12] 294,912\n",
" BatchNorm2d-226 [-1, 256, 69, 12] 512\n",
" BasicConv2d-227 [-1, 256, 69, 12] 0\n",
" ZeroPad2d-228 [-1, 448, 69, 12] 0\n",
" Conv2d-229 [-1, 85, 69, 12] 38,080\n",
" BatchNorm2d-230 [-1, 85, 69, 12] 170\n",
" BasicConv2d-231 [-1, 85, 69, 12] 0\n",
" ZeroPad2d-232 [-1, 85, 71, 14] 0\n",
" Conv2d-233 [-1, 128, 69, 12] 97,920\n",
" BatchNorm2d-234 [-1, 128, 69, 12] 256\n",
" BasicConv2d-235 [-1, 128, 69, 12] 0\n",
" ZeroPad2d-236 [-1, 128, 71, 14] 0\n",
" Conv2d-237 [-1, 256, 69, 12] 294,912\n",
" BatchNorm2d-238 [-1, 256, 69, 12] 512\n",
" BasicConv2d-239 [-1, 256, 69, 12] 0\n",
" BatchNorm2d-240 [-1, 597, 69, 12] 1,194\n",
"Multi_2D_CNN_block-241 [-1, 597, 69, 12] 0\n",
" MaxPool2d-242 [-1, 597, 34, 12] 0\n",
" ZeroPad2d-243 [-1, 597, 34, 12] 0\n",
" Conv2d-244 [-1, 106, 34, 12] 63,282\n",
" BatchNorm2d-245 [-1, 106, 34, 12] 212\n",
" BasicConv2d-246 [-1, 106, 34, 12] 0\n",
" ZeroPad2d-247 [-1, 597, 34, 12] 0\n",
" Conv2d-248 [-1, 160, 34, 12] 95,520\n",
" BatchNorm2d-249 [-1, 160, 34, 12] 320\n",
" BasicConv2d-250 [-1, 160, 34, 12] 0\n",
" ZeroPad2d-251 [-1, 160, 36, 14] 0\n",
" Conv2d-252 [-1, 320, 34, 12] 460,800\n",
" BatchNorm2d-253 [-1, 320, 34, 12] 640\n",
" BasicConv2d-254 [-1, 320, 34, 12] 0\n",
" ZeroPad2d-255 [-1, 597, 34, 12] 0\n",
" Conv2d-256 [-1, 106, 34, 12] 63,282\n",
" BatchNorm2d-257 [-1, 106, 34, 12] 212\n",
" BasicConv2d-258 [-1, 106, 34, 12] 0\n",
" ZeroPad2d-259 [-1, 106, 36, 14] 0\n",
" Conv2d-260 [-1, 160, 34, 12] 152,640\n",
" BatchNorm2d-261 [-1, 160, 34, 12] 320\n",
" BasicConv2d-262 [-1, 160, 34, 12] 0\n",
" ZeroPad2d-263 [-1, 160, 36, 14] 0\n",
" Conv2d-264 [-1, 320, 34, 12] 460,800\n",
" BatchNorm2d-265 [-1, 320, 34, 12] 640\n",
" BasicConv2d-266 [-1, 320, 34, 12] 0\n",
" BatchNorm2d-267 [-1, 746, 34, 12] 1,492\n",
"Multi_2D_CNN_block-268 [-1, 746, 34, 12] 0\n",
" ZeroPad2d-269 [-1, 746, 34, 12] 0\n",
" Conv2d-270 [-1, 128, 34, 12] 95,488\n",
" BatchNorm2d-271 [-1, 128, 34, 12] 256\n",
" BasicConv2d-272 [-1, 128, 34, 12] 0\n",
" ZeroPad2d-273 [-1, 746, 34, 12] 0\n",
" Conv2d-274 [-1, 192, 34, 12] 143,232\n",
" BatchNorm2d-275 [-1, 192, 34, 12] 384\n",
" BasicConv2d-276 [-1, 192, 34, 12] 0\n",
" ZeroPad2d-277 [-1, 192, 36, 14] 0\n",
" Conv2d-278 [-1, 384, 34, 12] 663,552\n",
" BatchNorm2d-279 [-1, 384, 34, 12] 768\n",
" BasicConv2d-280 [-1, 384, 34, 12] 0\n",
" ZeroPad2d-281 [-1, 746, 34, 12] 0\n",
" Conv2d-282 [-1, 128, 34, 12] 95,488\n",
" BatchNorm2d-283 [-1, 128, 34, 12] 256\n",
" BasicConv2d-284 [-1, 128, 34, 12] 0\n",
" ZeroPad2d-285 [-1, 128, 36, 14] 0\n",
" Conv2d-286 [-1, 192, 34, 12] 221,184\n",
" BatchNorm2d-287 [-1, 192, 34, 12] 384\n",
" BasicConv2d-288 [-1, 192, 34, 12] 0\n",
" ZeroPad2d-289 [-1, 192, 36, 14] 0\n",
" Conv2d-290 [-1, 384, 34, 12] 663,552\n",
" BatchNorm2d-291 [-1, 384, 34, 12] 768\n",
" BasicConv2d-292 [-1, 384, 34, 12] 0\n",
" BatchNorm2d-293 [-1, 896, 34, 12] 1,792\n",
"Multi_2D_CNN_block-294 [-1, 896, 34, 12] 0\n",
" ZeroPad2d-295 [-1, 896, 34, 12] 0\n",
" Conv2d-296 [-1, 149, 34, 12] 133,504\n",
" BatchNorm2d-297 [-1, 149, 34, 12] 298\n",
" BasicConv2d-298 [-1, 149, 34, 12] 0\n",
" ZeroPad2d-299 [-1, 896, 34, 12] 0\n",
" Conv2d-300 [-1, 224, 34, 12] 200,704\n",
" BatchNorm2d-301 [-1, 224, 34, 12] 448\n",
" BasicConv2d-302 [-1, 224, 34, 12] 0\n",
" ZeroPad2d-303 [-1, 224, 36, 14] 0\n",
" Conv2d-304 [-1, 448, 34, 12] 903,168\n",
" BatchNorm2d-305 [-1, 448, 34, 12] 896\n",
" BasicConv2d-306 [-1, 448, 34, 12] 0\n",
" ZeroPad2d-307 [-1, 896, 34, 12] 0\n",
" Conv2d-308 [-1, 149, 34, 12] 133,504\n",
" BatchNorm2d-309 [-1, 149, 34, 12] 298\n",
" BasicConv2d-310 [-1, 149, 34, 12] 0\n",
" ZeroPad2d-311 [-1, 149, 36, 14] 0\n",
" Conv2d-312 [-1, 224, 34, 12] 300,384\n",
" BatchNorm2d-313 [-1, 224, 34, 12] 448\n",
" BasicConv2d-314 [-1, 224, 34, 12] 0\n",
" ZeroPad2d-315 [-1, 224, 36, 14] 0\n",
" Conv2d-316 [-1, 448, 34, 12] 903,168\n",
" BatchNorm2d-317 [-1, 448, 34, 12] 896\n",
" BasicConv2d-318 [-1, 448, 34, 12] 0\n",
" BatchNorm2d-319 [-1, 1045, 34, 12] 2,090\n",
"Multi_2D_CNN_block-320 [-1, 1045, 34, 12] 0\n",
" MaxPool2d-321 [-1, 1045, 17, 12] 0\n",
" ZeroPad2d-322 [-1, 1045, 17, 12] 0\n",
" Conv2d-323 [-1, 170, 17, 12] 177,650\n",
" BatchNorm2d-324 [-1, 170, 17, 12] 340\n",
" BasicConv2d-325 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-326 [-1, 1045, 17, 12] 0\n",
" Conv2d-327 [-1, 256, 17, 12] 267,520\n",
" BatchNorm2d-328 [-1, 256, 17, 12] 512\n",
" BasicConv2d-329 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-330 [-1, 256, 19, 14] 0\n",
" Conv2d-331 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-332 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-333 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-334 [-1, 1045, 17, 12] 0\n",
" Conv2d-335 [-1, 170, 17, 12] 177,650\n",
" BatchNorm2d-336 [-1, 170, 17, 12] 340\n",
" BasicConv2d-337 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-338 [-1, 170, 19, 14] 0\n",
" Conv2d-339 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-340 [-1, 256, 17, 12] 512\n",
" BasicConv2d-341 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-342 [-1, 256, 19, 14] 0\n",
" Conv2d-343 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-344 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-345 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-346 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-347 [-1, 1194, 17, 12] 0\n",
" ZeroPad2d-348 [-1, 1194, 17, 12] 0\n",
" Conv2d-349 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-350 [-1, 170, 17, 12] 340\n",
" BasicConv2d-351 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-352 [-1, 1194, 17, 12] 0\n",
" Conv2d-353 [-1, 256, 17, 12] 305,664\n",
" BatchNorm2d-354 [-1, 256, 17, 12] 512\n",
" BasicConv2d-355 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-356 [-1, 256, 19, 14] 0\n",
" Conv2d-357 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-358 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-359 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-360 [-1, 1194, 17, 12] 0\n",
" Conv2d-361 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-362 [-1, 170, 17, 12] 340\n",
" BasicConv2d-363 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-364 [-1, 170, 19, 14] 0\n",
" Conv2d-365 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-366 [-1, 256, 17, 12] 512\n",
" BasicConv2d-367 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-368 [-1, 256, 19, 14] 0\n",
" Conv2d-369 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-370 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-371 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-372 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-373 [-1, 1194, 17, 12] 0\n",
" ZeroPad2d-374 [-1, 1194, 17, 12] 0\n",
" Conv2d-375 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-376 [-1, 170, 17, 12] 340\n",
" BasicConv2d-377 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-378 [-1, 1194, 17, 12] 0\n",
" Conv2d-379 [-1, 256, 17, 12] 305,664\n",
" BatchNorm2d-380 [-1, 256, 17, 12] 512\n",
" BasicConv2d-381 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-382 [-1, 256, 19, 14] 0\n",
" Conv2d-383 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-384 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-385 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-386 [-1, 1194, 17, 12] 0\n",
" Conv2d-387 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-388 [-1, 170, 17, 12] 340\n",
" BasicConv2d-389 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-390 [-1, 170, 19, 14] 0\n",
" Conv2d-391 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-392 [-1, 256, 17, 12] 512\n",
" BasicConv2d-393 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-394 [-1, 256, 19, 14] 0\n",
" Conv2d-395 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-396 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-397 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-398 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-399 [-1, 1194, 17, 12] 0\n",
" MaxPool2d-400 [-1, 1194, 8, 12] 0\n",
" ZeroPad2d-401 [-1, 1194, 8, 12] 0\n",
" Conv2d-402 [-1, 256, 8, 12] 305,664\n",
" BatchNorm2d-403 [-1, 256, 8, 12] 512\n",
" BasicConv2d-404 [-1, 256, 8, 12] 0\n",
" ZeroPad2d-405 [-1, 1194, 8, 12] 0\n",
" Conv2d-406 [-1, 384, 8, 12] 458,496\n",
" BatchNorm2d-407 [-1, 384, 8, 12] 768\n",
" BasicConv2d-408 [-1, 384, 8, 12] 0\n",
" ZeroPad2d-409 [-1, 384, 10, 14] 0\n",
" Conv2d-410 [-1, 768, 8, 12] 2,654,208\n",
" BatchNorm2d-411 [-1, 768, 8, 12] 1,536\n",
" BasicConv2d-412 [-1, 768, 8, 12] 0\n",
" ZeroPad2d-413 [-1, 1194, 8, 12] 0\n",
" Conv2d-414 [-1, 256, 8, 12] 305,664\n",
" BatchNorm2d-415 [-1, 256, 8, 12] 512\n",
" BasicConv2d-416 [-1, 256, 8, 12] 0\n",
" ZeroPad2d-417 [-1, 256, 10, 14] 0\n",
" Conv2d-418 [-1, 384, 8, 12] 884,736\n",
" BatchNorm2d-419 [-1, 384, 8, 12] 768\n",
" BasicConv2d-420 [-1, 384, 8, 12] 0\n",
" ZeroPad2d-421 [-1, 384, 10, 14] 0\n",
" Conv2d-422 [-1, 768, 8, 12] 2,654,208\n",
" BatchNorm2d-423 [-1, 768, 8, 12] 1,536\n",
" BasicConv2d-424 [-1, 768, 8, 12] 0\n",
" BatchNorm2d-425 [-1, 1792, 8, 12] 3,584\n",
"Multi_2D_CNN_block-426 [-1, 1792, 8, 12] 0\n",
" ZeroPad2d-427 [-1, 1792, 8, 12] 0\n",
" Conv2d-428 [-1, 298, 8, 12] 534,016\n",
" BatchNorm2d-429 [-1, 298, 8, 12] 596\n",
" BasicConv2d-430 [-1, 298, 8, 12] 0\n",
" ZeroPad2d-431 [-1, 1792, 8, 12] 0\n",
" Conv2d-432 [-1, 448, 8, 12] 802,816\n",
" BatchNorm2d-433 [-1, 448, 8, 12] 896\n",
" BasicConv2d-434 [-1, 448, 8, 12] 0\n",
" ZeroPad2d-435 [-1, 448, 10, 14] 0\n",
" Conv2d-436 [-1, 896, 8, 12] 3,612,672\n",
" BatchNorm2d-437 [-1, 896, 8, 12] 1,792\n",
" BasicConv2d-438 [-1, 896, 8, 12] 0\n",
" ZeroPad2d-439 [-1, 1792, 8, 12] 0\n",
" Conv2d-440 [-1, 298, 8, 12] 534,016\n",
" BatchNorm2d-441 [-1, 298, 8, 12] 596\n",
" BasicConv2d-442 [-1, 298, 8, 12] 0\n",
" ZeroPad2d-443 [-1, 298, 10, 14] 0\n",
" Conv2d-444 [-1, 448, 8, 12] 1,201,536\n",
" BatchNorm2d-445 [-1, 448, 8, 12] 896\n",
" BasicConv2d-446 [-1, 448, 8, 12] 0\n",
" ZeroPad2d-447 [-1, 448, 10, 14] 0\n",
" Conv2d-448 [-1, 896, 8, 12] 3,612,672\n",
" BatchNorm2d-449 [-1, 896, 8, 12] 1,792\n",
" BasicConv2d-450 [-1, 896, 8, 12] 0\n",
" BatchNorm2d-451 [-1, 2090, 8, 12] 4,180\n",
"Multi_2D_CNN_block-452 [-1, 2090, 8, 12] 0\n",
" ZeroPad2d-453 [-1, 2090, 8, 12] 0\n",
" Conv2d-454 [-1, 341, 8, 12] 712,690\n",
" BatchNorm2d-455 [-1, 341, 8, 12] 682\n",
" BasicConv2d-456 [-1, 341, 8, 12] 0\n",
" ZeroPad2d-457 [-1, 2090, 8, 12] 0\n",
" Conv2d-458 [-1, 512, 8, 12] 1,070,080\n",
" BatchNorm2d-459 [-1, 512, 8, 12] 1,024\n",
" BasicConv2d-460 [-1, 512, 8, 12] 0\n",
" ZeroPad2d-461 [-1, 512, 10, 14] 0\n",
" Conv2d-462 [-1, 1024, 8, 12] 4,718,592\n",
" BatchNorm2d-463 [-1, 1024, 8, 12] 2,048\n",
" BasicConv2d-464 [-1, 1024, 8, 12] 0\n",
" ZeroPad2d-465 [-1, 2090, 8, 12] 0\n",
" Conv2d-466 [-1, 341, 8, 12] 712,690\n",
" BatchNorm2d-467 [-1, 341, 8, 12] 682\n",
" BasicConv2d-468 [-1, 341, 8, 12] 0\n",
" ZeroPad2d-469 [-1, 341, 10, 14] 0\n",
" Conv2d-470 [-1, 512, 8, 12] 1,571,328\n",
" BatchNorm2d-471 [-1, 512, 8, 12] 1,024\n",
" BasicConv2d-472 [-1, 512, 8, 12] 0\n",
" ZeroPad2d-473 [-1, 512, 10, 14] 0\n",
" Conv2d-474 [-1, 1024, 8, 12] 4,718,592\n",
" BatchNorm2d-475 [-1, 1024, 8, 12] 2,048\n",
" BasicConv2d-476 [-1, 1024, 8, 12] 0\n",
" BatchNorm2d-477 [-1, 2389, 8, 12] 4,778\n",
"Multi_2D_CNN_block-478 [-1, 2389, 8, 12] 0\n",
"AdaptiveAvgPool2d-479 [-1, 2389, 1, 1] 0\n",
" Flatten-480 [-1, 2389] 0\n",
" Dropout-481 [-1, 2389] 0\n",
" Linear-482 [-1, 1] 2,390\n",
"================================================================\n",
"Total params: 49,719,798\n",
"Trainable params: 49,719,798\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.11\n",
"Forward/backward pass size (MB): 884.62\n",
"Params size (MB): 189.67\n",
"Estimated Total Size (MB): 1074.40\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"summary(model, input_size=(1, 2500, 12))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" ZeroPad2d-1 [-1, 1, 2506, 14] 0\n",
" Conv2d-2 [-1, 64, 1250, 12] 1,344\n",
" BatchNorm2d-3 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-4 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-5 [-1, 64, 1250, 12] 0\n",
" Conv2d-6 [-1, 21, 1250, 12] 1,344\n",
" BatchNorm2d-7 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-8 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-9 [-1, 64, 1250, 12] 0\n",
" Conv2d-10 [-1, 32, 1250, 12] 2,048\n",
" BatchNorm2d-11 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-12 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-13 [-1, 32, 1252, 14] 0\n",
" Conv2d-14 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-15 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-16 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-17 [-1, 64, 1250, 12] 0\n",
" Conv2d-18 [-1, 21, 1250, 12] 1,344\n",
" BatchNorm2d-19 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-20 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-21 [-1, 21, 1252, 14] 0\n",
" Conv2d-22 [-1, 32, 1250, 12] 6,048\n",
" BatchNorm2d-23 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-24 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-25 [-1, 32, 1252, 14] 0\n",
" Conv2d-26 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-27 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-28 [-1, 64, 1250, 12] 0\n",
" BatchNorm2d-29 [-1, 149, 1250, 12] 298\n",
"Multi_2D_CNN_block-30 [-1, 149, 1250, 12] 0\n",
" ZeroPad2d-31 [-1, 149, 1250, 12] 0\n",
" Conv2d-32 [-1, 21, 1250, 12] 3,129\n",
" BatchNorm2d-33 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-34 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-35 [-1, 149, 1250, 12] 0\n",
" Conv2d-36 [-1, 32, 1250, 12] 4,768\n",
" BatchNorm2d-37 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-38 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-39 [-1, 32, 1252, 14] 0\n",
" Conv2d-40 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-41 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-42 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-43 [-1, 149, 1250, 12] 0\n",
" Conv2d-44 [-1, 21, 1250, 12] 3,129\n",
" BatchNorm2d-45 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-46 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-47 [-1, 21, 1252, 14] 0\n",
" Conv2d-48 [-1, 32, 1250, 12] 6,048\n",
" BatchNorm2d-49 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-50 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-51 [-1, 32, 1252, 14] 0\n",
" Conv2d-52 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-53 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-54 [-1, 64, 1250, 12] 0\n",
" BatchNorm2d-55 [-1, 149, 1250, 12] 298\n",
"Multi_2D_CNN_block-56 [-1, 149, 1250, 12] 0\n",
" MaxPool2d-57 [-1, 149, 416, 12] 0\n",
" ZeroPad2d-58 [-1, 149, 416, 12] 0\n",
" Conv2d-59 [-1, 32, 416, 12] 4,768\n",
" BatchNorm2d-60 [-1, 32, 416, 12] 64\n",
" BasicConv2d-61 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-62 [-1, 149, 416, 12] 0\n",
" Conv2d-63 [-1, 48, 416, 12] 7,152\n",
" BatchNorm2d-64 [-1, 48, 416, 12] 96\n",
" BasicConv2d-65 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-66 [-1, 48, 418, 14] 0\n",
" Conv2d-67 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-68 [-1, 96, 416, 12] 192\n",
" BasicConv2d-69 [-1, 96, 416, 12] 0\n",
" ZeroPad2d-70 [-1, 149, 416, 12] 0\n",
" Conv2d-71 [-1, 32, 416, 12] 4,768\n",
" BatchNorm2d-72 [-1, 32, 416, 12] 64\n",
" BasicConv2d-73 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-74 [-1, 32, 418, 14] 0\n",
" Conv2d-75 [-1, 48, 416, 12] 13,824\n",
" BatchNorm2d-76 [-1, 48, 416, 12] 96\n",
" BasicConv2d-77 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-78 [-1, 48, 418, 14] 0\n",
" Conv2d-79 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-80 [-1, 96, 416, 12] 192\n",
" BasicConv2d-81 [-1, 96, 416, 12] 0\n",
" BatchNorm2d-82 [-1, 224, 416, 12] 448\n",
"Multi_2D_CNN_block-83 [-1, 224, 416, 12] 0\n",
" ZeroPad2d-84 [-1, 224, 416, 12] 0\n",
" Conv2d-85 [-1, 32, 416, 12] 7,168\n",
" BatchNorm2d-86 [-1, 32, 416, 12] 64\n",
" BasicConv2d-87 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-88 [-1, 224, 416, 12] 0\n",
" Conv2d-89 [-1, 48, 416, 12] 10,752\n",
" BatchNorm2d-90 [-1, 48, 416, 12] 96\n",
" BasicConv2d-91 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-92 [-1, 48, 418, 14] 0\n",
" Conv2d-93 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-94 [-1, 96, 416, 12] 192\n",
" BasicConv2d-95 [-1, 96, 416, 12] 0\n",
" ZeroPad2d-96 [-1, 224, 416, 12] 0\n",
" Conv2d-97 [-1, 32, 416, 12] 7,168\n",
" BatchNorm2d-98 [-1, 32, 416, 12] 64\n",
" BasicConv2d-99 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-100 [-1, 32, 418, 14] 0\n",
" Conv2d-101 [-1, 48, 416, 12] 13,824\n",
" BatchNorm2d-102 [-1, 48, 416, 12] 96\n",
" BasicConv2d-103 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-104 [-1, 48, 418, 14] 0\n",
" Conv2d-105 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-106 [-1, 96, 416, 12] 192\n",
" BasicConv2d-107 [-1, 96, 416, 12] 0\n",
" BatchNorm2d-108 [-1, 224, 416, 12] 448\n",
"Multi_2D_CNN_block-109 [-1, 224, 416, 12] 0\n",
" MaxPool2d-110 [-1, 224, 138, 12] 0\n",
" ZeroPad2d-111 [-1, 224, 138, 12] 0\n",
" Conv2d-112 [-1, 42, 138, 12] 9,408\n",
" BatchNorm2d-113 [-1, 42, 138, 12] 84\n",
" BasicConv2d-114 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-115 [-1, 224, 138, 12] 0\n",
" Conv2d-116 [-1, 64, 138, 12] 14,336\n",
" BatchNorm2d-117 [-1, 64, 138, 12] 128\n",
" BasicConv2d-118 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-119 [-1, 64, 140, 14] 0\n",
" Conv2d-120 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-121 [-1, 128, 138, 12] 256\n",
" BasicConv2d-122 [-1, 128, 138, 12] 0\n",
" ZeroPad2d-123 [-1, 224, 138, 12] 0\n",
" Conv2d-124 [-1, 42, 138, 12] 9,408\n",
" BatchNorm2d-125 [-1, 42, 138, 12] 84\n",
" BasicConv2d-126 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-127 [-1, 42, 140, 14] 0\n",
" Conv2d-128 [-1, 64, 138, 12] 24,192\n",
" BatchNorm2d-129 [-1, 64, 138, 12] 128\n",
" BasicConv2d-130 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-131 [-1, 64, 140, 14] 0\n",
" Conv2d-132 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-133 [-1, 128, 138, 12] 256\n",
" BasicConv2d-134 [-1, 128, 138, 12] 0\n",
" BatchNorm2d-135 [-1, 298, 138, 12] 596\n",
"Multi_2D_CNN_block-136 [-1, 298, 138, 12] 0\n",
" ZeroPad2d-137 [-1, 298, 138, 12] 0\n",
" Conv2d-138 [-1, 42, 138, 12] 12,516\n",
" BatchNorm2d-139 [-1, 42, 138, 12] 84\n",
" BasicConv2d-140 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-141 [-1, 298, 138, 12] 0\n",
" Conv2d-142 [-1, 64, 138, 12] 19,072\n",
" BatchNorm2d-143 [-1, 64, 138, 12] 128\n",
" BasicConv2d-144 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-145 [-1, 64, 140, 14] 0\n",
" Conv2d-146 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-147 [-1, 128, 138, 12] 256\n",
" BasicConv2d-148 [-1, 128, 138, 12] 0\n",
" ZeroPad2d-149 [-1, 298, 138, 12] 0\n",
" Conv2d-150 [-1, 42, 138, 12] 12,516\n",
" BatchNorm2d-151 [-1, 42, 138, 12] 84\n",
" BasicConv2d-152 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-153 [-1, 42, 140, 14] 0\n",
" Conv2d-154 [-1, 64, 138, 12] 24,192\n",
" BatchNorm2d-155 [-1, 64, 138, 12] 128\n",
" BasicConv2d-156 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-157 [-1, 64, 140, 14] 0\n",
" Conv2d-158 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-159 [-1, 128, 138, 12] 256\n",
" BasicConv2d-160 [-1, 128, 138, 12] 0\n",
" BatchNorm2d-161 [-1, 298, 138, 12] 596\n",
"Multi_2D_CNN_block-162 [-1, 298, 138, 12] 0\n",
" MaxPool2d-163 [-1, 298, 69, 12] 0\n",
" ZeroPad2d-164 [-1, 298, 69, 12] 0\n",
" Conv2d-165 [-1, 64, 69, 12] 19,072\n",
" BatchNorm2d-166 [-1, 64, 69, 12] 128\n",
" BasicConv2d-167 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-168 [-1, 298, 69, 12] 0\n",
" Conv2d-169 [-1, 96, 69, 12] 28,608\n",
" BatchNorm2d-170 [-1, 96, 69, 12] 192\n",
" BasicConv2d-171 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-172 [-1, 96, 71, 14] 0\n",
" Conv2d-173 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-174 [-1, 192, 69, 12] 384\n",
" BasicConv2d-175 [-1, 192, 69, 12] 0\n",
" ZeroPad2d-176 [-1, 298, 69, 12] 0\n",
" Conv2d-177 [-1, 64, 69, 12] 19,072\n",
" BatchNorm2d-178 [-1, 64, 69, 12] 128\n",
" BasicConv2d-179 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-180 [-1, 64, 71, 14] 0\n",
" Conv2d-181 [-1, 96, 69, 12] 55,296\n",
" BatchNorm2d-182 [-1, 96, 69, 12] 192\n",
" BasicConv2d-183 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-184 [-1, 96, 71, 14] 0\n",
" Conv2d-185 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-186 [-1, 192, 69, 12] 384\n",
" BasicConv2d-187 [-1, 192, 69, 12] 0\n",
" BatchNorm2d-188 [-1, 448, 69, 12] 896\n",
"Multi_2D_CNN_block-189 [-1, 448, 69, 12] 0\n",
" ZeroPad2d-190 [-1, 448, 69, 12] 0\n",
" Conv2d-191 [-1, 64, 69, 12] 28,672\n",
" BatchNorm2d-192 [-1, 64, 69, 12] 128\n",
" BasicConv2d-193 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-194 [-1, 448, 69, 12] 0\n",
" Conv2d-195 [-1, 96, 69, 12] 43,008\n",
" BatchNorm2d-196 [-1, 96, 69, 12] 192\n",
" BasicConv2d-197 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-198 [-1, 96, 71, 14] 0\n",
" Conv2d-199 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-200 [-1, 192, 69, 12] 384\n",
" BasicConv2d-201 [-1, 192, 69, 12] 0\n",
" ZeroPad2d-202 [-1, 448, 69, 12] 0\n",
" Conv2d-203 [-1, 64, 69, 12] 28,672\n",
" BatchNorm2d-204 [-1, 64, 69, 12] 128\n",
" BasicConv2d-205 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-206 [-1, 64, 71, 14] 0\n",
" Conv2d-207 [-1, 96, 69, 12] 55,296\n",
" BatchNorm2d-208 [-1, 96, 69, 12] 192\n",
" BasicConv2d-209 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-210 [-1, 96, 71, 14] 0\n",
" Conv2d-211 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-212 [-1, 192, 69, 12] 384\n",
" BasicConv2d-213 [-1, 192, 69, 12] 0\n",
" BatchNorm2d-214 [-1, 448, 69, 12] 896\n",
"Multi_2D_CNN_block-215 [-1, 448, 69, 12] 0\n",
" ZeroPad2d-216 [-1, 448, 69, 12] 0\n",
" Conv2d-217 [-1, 85, 69, 12] 38,080\n",
" BatchNorm2d-218 [-1, 85, 69, 12] 170\n",
" BasicConv2d-219 [-1, 85, 69, 12] 0\n",
" ZeroPad2d-220 [-1, 448, 69, 12] 0\n",
" Conv2d-221 [-1, 128, 69, 12] 57,344\n",
" BatchNorm2d-222 [-1, 128, 69, 12] 256\n",
" BasicConv2d-223 [-1, 128, 69, 12] 0\n",
" ZeroPad2d-224 [-1, 128, 71, 14] 0\n",
" Conv2d-225 [-1, 256, 69, 12] 294,912\n",
" BatchNorm2d-226 [-1, 256, 69, 12] 512\n",
" BasicConv2d-227 [-1, 256, 69, 12] 0\n",
" ZeroPad2d-228 [-1, 448, 69, 12] 0\n",
" Conv2d-229 [-1, 85, 69, 12] 38,080\n",
" BatchNorm2d-230 [-1, 85, 69, 12] 170\n",
" BasicConv2d-231 [-1, 85, 69, 12] 0\n",
" ZeroPad2d-232 [-1, 85, 71, 14] 0\n",
" Conv2d-233 [-1, 128, 69, 12] 97,920\n",
" BatchNorm2d-234 [-1, 128, 69, 12] 256\n",
" BasicConv2d-235 [-1, 128, 69, 12] 0\n",
" ZeroPad2d-236 [-1, 128, 71, 14] 0\n",
" Conv2d-237 [-1, 256, 69, 12] 294,912\n",
" BatchNorm2d-238 [-1, 256, 69, 12] 512\n",
" BasicConv2d-239 [-1, 256, 69, 12] 0\n",
" BatchNorm2d-240 [-1, 597, 69, 12] 1,194\n",
"Multi_2D_CNN_block-241 [-1, 597, 69, 12] 0\n",
" MaxPool2d-242 [-1, 597, 34, 12] 0\n",
" ZeroPad2d-243 [-1, 597, 34, 12] 0\n",
" Conv2d-244 [-1, 106, 34, 12] 63,282\n",
" BatchNorm2d-245 [-1, 106, 34, 12] 212\n",
" BasicConv2d-246 [-1, 106, 34, 12] 0\n",
" ZeroPad2d-247 [-1, 597, 34, 12] 0\n",
" Conv2d-248 [-1, 160, 34, 12] 95,520\n",
" BatchNorm2d-249 [-1, 160, 34, 12] 320\n",
" BasicConv2d-250 [-1, 160, 34, 12] 0\n",
" ZeroPad2d-251 [-1, 160, 36, 14] 0\n",
" Conv2d-252 [-1, 320, 34, 12] 460,800\n",
" BatchNorm2d-253 [-1, 320, 34, 12] 640\n",
" BasicConv2d-254 [-1, 320, 34, 12] 0\n",
" ZeroPad2d-255 [-1, 597, 34, 12] 0\n",
" Conv2d-256 [-1, 106, 34, 12] 63,282\n",
" BatchNorm2d-257 [-1, 106, 34, 12] 212\n",
" BasicConv2d-258 [-1, 106, 34, 12] 0\n",
" ZeroPad2d-259 [-1, 106, 36, 14] 0\n",
" Conv2d-260 [-1, 160, 34, 12] 152,640\n",
" BatchNorm2d-261 [-1, 160, 34, 12] 320\n",
" BasicConv2d-262 [-1, 160, 34, 12] 0\n",
" ZeroPad2d-263 [-1, 160, 36, 14] 0\n",
" Conv2d-264 [-1, 320, 34, 12] 460,800\n",
" BatchNorm2d-265 [-1, 320, 34, 12] 640\n",
" BasicConv2d-266 [-1, 320, 34, 12] 0\n",
" BatchNorm2d-267 [-1, 746, 34, 12] 1,492\n",
"Multi_2D_CNN_block-268 [-1, 746, 34, 12] 0\n",
" ZeroPad2d-269 [-1, 746, 34, 12] 0\n",
" Conv2d-270 [-1, 128, 34, 12] 95,488\n",
" BatchNorm2d-271 [-1, 128, 34, 12] 256\n",
" BasicConv2d-272 [-1, 128, 34, 12] 0\n",
" ZeroPad2d-273 [-1, 746, 34, 12] 0\n",
" Conv2d-274 [-1, 192, 34, 12] 143,232\n",
" BatchNorm2d-275 [-1, 192, 34, 12] 384\n",
" BasicConv2d-276 [-1, 192, 34, 12] 0\n",
" ZeroPad2d-277 [-1, 192, 36, 14] 0\n",
" Conv2d-278 [-1, 384, 34, 12] 663,552\n",
" BatchNorm2d-279 [-1, 384, 34, 12] 768\n",
" BasicConv2d-280 [-1, 384, 34, 12] 0\n",
" ZeroPad2d-281 [-1, 746, 34, 12] 0\n",
" Conv2d-282 [-1, 128, 34, 12] 95,488\n",
" BatchNorm2d-283 [-1, 128, 34, 12] 256\n",
" BasicConv2d-284 [-1, 128, 34, 12] 0\n",
" ZeroPad2d-285 [-1, 128, 36, 14] 0\n",
" Conv2d-286 [-1, 192, 34, 12] 221,184\n",
" BatchNorm2d-287 [-1, 192, 34, 12] 384\n",
" BasicConv2d-288 [-1, 192, 34, 12] 0\n",
" ZeroPad2d-289 [-1, 192, 36, 14] 0\n",
" Conv2d-290 [-1, 384, 34, 12] 663,552\n",
" BatchNorm2d-291 [-1, 384, 34, 12] 768\n",
" BasicConv2d-292 [-1, 384, 34, 12] 0\n",
" BatchNorm2d-293 [-1, 896, 34, 12] 1,792\n",
"Multi_2D_CNN_block-294 [-1, 896, 34, 12] 0\n",
" ZeroPad2d-295 [-1, 896, 34, 12] 0\n",
" Conv2d-296 [-1, 149, 34, 12] 133,504\n",
" BatchNorm2d-297 [-1, 149, 34, 12] 298\n",
" BasicConv2d-298 [-1, 149, 34, 12] 0\n",
" ZeroPad2d-299 [-1, 896, 34, 12] 0\n",
" Conv2d-300 [-1, 224, 34, 12] 200,704\n",
" BatchNorm2d-301 [-1, 224, 34, 12] 448\n",
" BasicConv2d-302 [-1, 224, 34, 12] 0\n",
" ZeroPad2d-303 [-1, 224, 36, 14] 0\n",
" Conv2d-304 [-1, 448, 34, 12] 903,168\n",
" BatchNorm2d-305 [-1, 448, 34, 12] 896\n",
" BasicConv2d-306 [-1, 448, 34, 12] 0\n",
" ZeroPad2d-307 [-1, 896, 34, 12] 0\n",
" Conv2d-308 [-1, 149, 34, 12] 133,504\n",
" BatchNorm2d-309 [-1, 149, 34, 12] 298\n",
" BasicConv2d-310 [-1, 149, 34, 12] 0\n",
" ZeroPad2d-311 [-1, 149, 36, 14] 0\n",
" Conv2d-312 [-1, 224, 34, 12] 300,384\n",
" BatchNorm2d-313 [-1, 224, 34, 12] 448\n",
" BasicConv2d-314 [-1, 224, 34, 12] 0\n",
" ZeroPad2d-315 [-1, 224, 36, 14] 0\n",
" Conv2d-316 [-1, 448, 34, 12] 903,168\n",
" BatchNorm2d-317 [-1, 448, 34, 12] 896\n",
" BasicConv2d-318 [-1, 448, 34, 12] 0\n",
" BatchNorm2d-319 [-1, 1045, 34, 12] 2,090\n",
"Multi_2D_CNN_block-320 [-1, 1045, 34, 12] 0\n",
" MaxPool2d-321 [-1, 1045, 17, 12] 0\n",
" ZeroPad2d-322 [-1, 1045, 17, 12] 0\n",
" Conv2d-323 [-1, 170, 17, 12] 177,650\n",
" BatchNorm2d-324 [-1, 170, 17, 12] 340\n",
" BasicConv2d-325 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-326 [-1, 1045, 17, 12] 0\n",
" Conv2d-327 [-1, 256, 17, 12] 267,520\n",
" BatchNorm2d-328 [-1, 256, 17, 12] 512\n",
" BasicConv2d-329 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-330 [-1, 256, 19, 14] 0\n",
" Conv2d-331 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-332 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-333 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-334 [-1, 1045, 17, 12] 0\n",
" Conv2d-335 [-1, 170, 17, 12] 177,650\n",
" BatchNorm2d-336 [-1, 170, 17, 12] 340\n",
" BasicConv2d-337 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-338 [-1, 170, 19, 14] 0\n",
" Conv2d-339 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-340 [-1, 256, 17, 12] 512\n",
" BasicConv2d-341 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-342 [-1, 256, 19, 14] 0\n",
" Conv2d-343 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-344 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-345 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-346 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-347 [-1, 1194, 17, 12] 0\n",
" ZeroPad2d-348 [-1, 1194, 17, 12] 0\n",
" Conv2d-349 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-350 [-1, 170, 17, 12] 340\n",
" BasicConv2d-351 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-352 [-1, 1194, 17, 12] 0\n",
" Conv2d-353 [-1, 256, 17, 12] 305,664\n",
" BatchNorm2d-354 [-1, 256, 17, 12] 512\n",
" BasicConv2d-355 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-356 [-1, 256, 19, 14] 0\n",
" Conv2d-357 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-358 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-359 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-360 [-1, 1194, 17, 12] 0\n",
" Conv2d-361 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-362 [-1, 170, 17, 12] 340\n",
" BasicConv2d-363 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-364 [-1, 170, 19, 14] 0\n",
" Conv2d-365 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-366 [-1, 256, 17, 12] 512\n",
" BasicConv2d-367 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-368 [-1, 256, 19, 14] 0\n",
" Conv2d-369 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-370 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-371 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-372 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-373 [-1, 1194, 17, 12] 0\n",
" ZeroPad2d-374 [-1, 1194, 17, 12] 0\n",
" Conv2d-375 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-376 [-1, 170, 17, 12] 340\n",
" BasicConv2d-377 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-378 [-1, 1194, 17, 12] 0\n",
" Conv2d-379 [-1, 256, 17, 12] 305,664\n",
" BatchNorm2d-380 [-1, 256, 17, 12] 512\n",
" BasicConv2d-381 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-382 [-1, 256, 19, 14] 0\n",
" Conv2d-383 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-384 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-385 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-386 [-1, 1194, 17, 12] 0\n",
" Conv2d-387 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-388 [-1, 170, 17, 12] 340\n",
" BasicConv2d-389 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-390 [-1, 170, 19, 14] 0\n",
" Conv2d-391 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-392 [-1, 256, 17, 12] 512\n",
" BasicConv2d-393 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-394 [-1, 256, 19, 14] 0\n",
" Conv2d-395 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-396 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-397 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-398 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-399 [-1, 1194, 17, 12] 0\n",
" MaxPool2d-400 [-1, 1194, 8, 12] 0\n",
" ZeroPad2d-401 [-1, 1194, 8, 12] 0\n",
" Conv2d-402 [-1, 256, 8, 12] 305,664\n",
" BatchNorm2d-403 [-1, 256, 8, 12] 512\n",
" BasicConv2d-404 [-1, 256, 8, 12] 0\n",
" ZeroPad2d-405 [-1, 1194, 8, 12] 0\n",
" Conv2d-406 [-1, 384, 8, 12] 458,496\n",
" BatchNorm2d-407 [-1, 384, 8, 12] 768\n",
" BasicConv2d-408 [-1, 384, 8, 12] 0\n",
" ZeroPad2d-409 [-1, 384, 10, 14] 0\n",
" Conv2d-410 [-1, 768, 8, 12] 2,654,208\n",
" BatchNorm2d-411 [-1, 768, 8, 12] 1,536\n",
" BasicConv2d-412 [-1, 768, 8, 12] 0\n",
" ZeroPad2d-413 [-1, 1194, 8, 12] 0\n",
" Conv2d-414 [-1, 256, 8, 12] 305,664\n",
" BatchNorm2d-415 [-1, 256, 8, 12] 512\n",
" BasicConv2d-416 [-1, 256, 8, 12] 0\n",
" ZeroPad2d-417 [-1, 256, 10, 14] 0\n",
" Conv2d-418 [-1, 384, 8, 12] 884,736\n",
" BatchNorm2d-419 [-1, 384, 8, 12] 768\n",
" BasicConv2d-420 [-1, 384, 8, 12] 0\n",
" ZeroPad2d-421 [-1, 384, 10, 14] 0\n",
" Conv2d-422 [-1, 768, 8, 12] 2,654,208\n",
" BatchNorm2d-423 [-1, 768, 8, 12] 1,536\n",
" BasicConv2d-424 [-1, 768, 8, 12] 0\n",
" BatchNorm2d-425 [-1, 1792, 8, 12] 3,584\n",
"Multi_2D_CNN_block-426 [-1, 1792, 8, 12] 0\n",
" ZeroPad2d-427 [-1, 1792, 8, 12] 0\n",
" Conv2d-428 [-1, 298, 8, 12] 534,016\n",
" BatchNorm2d-429 [-1, 298, 8, 12] 596\n",
" BasicConv2d-430 [-1, 298, 8, 12] 0\n",
" ZeroPad2d-431 [-1, 1792, 8, 12] 0\n",
" Conv2d-432 [-1, 448, 8, 12] 802,816\n",
" BatchNorm2d-433 [-1, 448, 8, 12] 896\n",
" BasicConv2d-434 [-1, 448, 8, 12] 0\n",
" ZeroPad2d-435 [-1, 448, 10, 14] 0\n",
" Conv2d-436 [-1, 896, 8, 12] 3,612,672\n",
" BatchNorm2d-437 [-1, 896, 8, 12] 1,792\n",
" BasicConv2d-438 [-1, 896, 8, 12] 0\n",
" ZeroPad2d-439 [-1, 1792, 8, 12] 0\n",
" Conv2d-440 [-1, 298, 8, 12] 534,016\n",
" BatchNorm2d-441 [-1, 298, 8, 12] 596\n",
" BasicConv2d-442 [-1, 298, 8, 12] 0\n",
" ZeroPad2d-443 [-1, 298, 10, 14] 0\n",
" Conv2d-444 [-1, 448, 8, 12] 1,201,536\n",
" BatchNorm2d-445 [-1, 448, 8, 12] 896\n",
" BasicConv2d-446 [-1, 448, 8, 12] 0\n",
" ZeroPad2d-447 [-1, 448, 10, 14] 0\n",
" Conv2d-448 [-1, 896, 8, 12] 3,612,672\n",
" BatchNorm2d-449 [-1, 896, 8, 12] 1,792\n",
" BasicConv2d-450 [-1, 896, 8, 12] 0\n",
" BatchNorm2d-451 [-1, 2090, 8, 12] 4,180\n",
"Multi_2D_CNN_block-452 [-1, 2090, 8, 12] 0\n",
" ZeroPad2d-453 [-1, 2090, 8, 12] 0\n",
" Conv2d-454 [-1, 341, 8, 12] 712,690\n",
" BatchNorm2d-455 [-1, 341, 8, 12] 682\n",
" BasicConv2d-456 [-1, 341, 8, 12] 0\n",
" ZeroPad2d-457 [-1, 2090, 8, 12] 0\n",
" Conv2d-458 [-1, 512, 8, 12] 1,070,080\n",
" BatchNorm2d-459 [-1, 512, 8, 12] 1,024\n",
" BasicConv2d-460 [-1, 512, 8, 12] 0\n",
" ZeroPad2d-461 [-1, 512, 10, 14] 0\n",
" Conv2d-462 [-1, 1024, 8, 12] 4,718,592\n",
" BatchNorm2d-463 [-1, 1024, 8, 12] 2,048\n",
" BasicConv2d-464 [-1, 1024, 8, 12] 0\n",
" ZeroPad2d-465 [-1, 2090, 8, 12] 0\n",
" Conv2d-466 [-1, 341, 8, 12] 712,690\n",
" BatchNorm2d-467 [-1, 341, 8, 12] 682\n",
" BasicConv2d-468 [-1, 341, 8, 12] 0\n",
" ZeroPad2d-469 [-1, 341, 10, 14] 0\n",
" Conv2d-470 [-1, 512, 8, 12] 1,571,328\n",
" BatchNorm2d-471 [-1, 512, 8, 12] 1,024\n",
" BasicConv2d-472 [-1, 512, 8, 12] 0\n",
" ZeroPad2d-473 [-1, 512, 10, 14] 0\n",
" Conv2d-474 [-1, 1024, 8, 12] 4,718,592\n",
" BatchNorm2d-475 [-1, 1024, 8, 12] 2,048\n",
" BasicConv2d-476 [-1, 1024, 8, 12] 0\n",
" BatchNorm2d-477 [-1, 2389, 8, 12] 4,778\n",
"Multi_2D_CNN_block-478 [-1, 2389, 8, 12] 0\n",
"AdaptiveAvgPool2d-479 [-1, 2389, 1, 1] 0\n",
" Flatten-480 [-1, 2389] 0\n",
" Dropout-481 [-1, 2389] 0\n",
" Linear-482 [-1, 1] 2,390\n",
"================================================================\n",
"Total params: 49,719,798\n",
"Trainable params: 49,719,798\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.11\n",
"Forward/backward pass size (MB): 884.62\n",
"Params size (MB): 189.67\n",
"Estimated Total Size (MB): 1074.40\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"summary(model, input_size=(1, 2500, 12))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"class Inception_Tabular_7_1(nn.Module):\n",
"    \"\"\"Inception-style 2D CNN whose stem convolution uses a (7, 1) kernel.\n",
"\n",
"    Layout: a BasicConv2d stem, seven stacks of Multi_2D_CNN_block (all but\n",
"    the last closed by a MaxPool2d that pools the first spatial axis only),\n",
"    and a head that average-pools, flattens, applies dropout and projects\n",
"    to one raw output value (nn.Sigmoid stays disabled).\n",
"\n",
"    NOTE(review): every in_channels value below is hard-coded; they\n",
"    presumably only match the block widths produced for the default\n",
"    initial_kernel_num=64 -- confirm before passing another value.\n",
"\n",
"    Args:\n",
"        initial_kernel_num: base value that is scaled per stack and handed\n",
"            to Multi_2D_CNN_block as num_kernel.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, initial_kernel_num=64):\n",
"        super().__init__()\n",
"        base = initial_kernel_num\n",
"\n",
"        def build_stack(pool_len, *specs):\n",
"            # specs are (in_channels, num_kernel) pairs, instantiated in order;\n",
"            # pool_len=None means the stack has no trailing MaxPool2d.\n",
"            layers = [Multi_2D_CNN_block(in_channels=c, num_kernel=n) for c, n in specs]\n",
"            if pool_len is not None:\n",
"                layers.append(nn.MaxPool2d(kernel_size=(pool_len, 1)))\n",
"            return nn.Sequential(*layers)\n",
"\n",
"        self.conv_1 = BasicConv2d(1, 64, kernel_size=(7, 1), stride=(2, 1))\n",
"\n",
"        self.multi_2d_cnn_1a = build_stack(3, (64, base), (149, base))\n",
"        self.multi_2d_cnn_1b = build_stack(3, (149, base * 1.5), (224, base * 1.5))\n",
"        self.multi_2d_cnn_1c = build_stack(2, (224, base * 2), (298, base * 2))\n",
"\n",
"        self.multi_2d_cnn_2a = build_stack(2, (298, base * 3), (448, base * 3), (448, base * 4))\n",
"        self.multi_2d_cnn_2b = build_stack(2, (597, base * 5), (746, base * 6), (896, base * 7))\n",
"        self.multi_2d_cnn_2c = build_stack(2, (1045, base * 8), (1194, base * 8), (1194, base * 8))\n",
"        self.multi_2d_cnn_2d = build_stack(None, (1194, base * 12), (1792, base * 14), (2090, base * 16))\n",
"\n",
"        self.output = nn.Sequential(\n",
"            nn.AdaptiveAvgPool2d((1, 1)),\n",
"            nn.Flatten(),\n",
"            nn.Dropout(0.5),\n",
"            nn.Linear(2389, 1),\n",
"            # nn.Sigmoid()\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Apply the stem, the seven block stacks in order, then the head.\n",
"\n",
"        The original author's notes give sizes (written H x W x C) of\n",
"        1250 x 12 x 64 after the stem, 416 x 12 x 149 after stack 1a,\n",
"        138 x 12 x 224 after 1b and 69 x 12 x 298 after 1c -- presumably\n",
"        for an (N, 1, 2500, 12) input; TODO confirm for this stem kernel.\n",
"        \"\"\"\n",
"        features = self.conv_1(x)\n",
"        for stack in (\n",
"            self.multi_2d_cnn_1a,\n",
"            self.multi_2d_cnn_1b,\n",
"            self.multi_2d_cnn_1c,\n",
"            self.multi_2d_cnn_2a,\n",
"            self.multi_2d_cnn_2b,\n",
"            self.multi_2d_cnn_2c,\n",
"            self.multi_2d_cnn_2d,\n",
"        ):\n",
"            features = stack(features)\n",
"        return self.output(features)\n",
"\n",
"class Inception_Tabular_15_3(nn.Module):\n",
"    \"\"\"Inception-style 2D CNN whose stem convolution uses a (15, 3) kernel.\n",
"\n",
"    Layout: a BasicConv2d stem, seven stacks of Multi_2D_CNN_block (all but\n",
"    the last closed by a MaxPool2d that pools the first spatial axis only),\n",
"    and a head that average-pools, flattens, applies dropout and projects\n",
"    to one raw output value (nn.Sigmoid stays disabled).\n",
"\n",
"    NOTE(review): every in_channels value below is hard-coded; they\n",
"    presumably only match the block widths produced for the default\n",
"    initial_kernel_num=64 -- confirm before passing another value.\n",
"\n",
"    Args:\n",
"        initial_kernel_num: base value that is scaled per stack and handed\n",
"            to Multi_2D_CNN_block as num_kernel.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, initial_kernel_num=64):\n",
"        super().__init__()\n",
"        width = initial_kernel_num\n",
"\n",
"        def make_stack(pool_len, *specs):\n",
"            # specs are (in_channels, num_kernel) pairs, instantiated in order;\n",
"            # pool_len=None means the stack has no trailing MaxPool2d.\n",
"            members = [Multi_2D_CNN_block(in_channels=c, num_kernel=n) for c, n in specs]\n",
"            if pool_len is not None:\n",
"                members.append(nn.MaxPool2d(kernel_size=(pool_len, 1)))\n",
"            return nn.Sequential(*members)\n",
"\n",
"        self.conv_1 = BasicConv2d(1, 64, kernel_size=(15, 3), stride=(2, 1))\n",
"\n",
"        self.multi_2d_cnn_1a = make_stack(3, (64, width), (149, width))\n",
"        self.multi_2d_cnn_1b = make_stack(3, (149, width * 1.5), (224, width * 1.5))\n",
"        self.multi_2d_cnn_1c = make_stack(2, (224, width * 2), (298, width * 2))\n",
"\n",
"        self.multi_2d_cnn_2a = make_stack(2, (298, width * 3), (448, width * 3), (448, width * 4))\n",
"        self.multi_2d_cnn_2b = make_stack(2, (597, width * 5), (746, width * 6), (896, width * 7))\n",
"        self.multi_2d_cnn_2c = make_stack(2, (1045, width * 8), (1194, width * 8), (1194, width * 8))\n",
"        self.multi_2d_cnn_2d = make_stack(None, (1194, width * 12), (1792, width * 14), (2090, width * 16))\n",
"\n",
"        self.output = nn.Sequential(\n",
"            nn.AdaptiveAvgPool2d((1, 1)),\n",
"            nn.Flatten(),\n",
"            nn.Dropout(0.5),\n",
"            nn.Linear(2389, 1),\n",
"            # nn.Sigmoid()\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Apply the stem, the seven block stacks in order, then the head.\n",
"\n",
"        The original author's notes give sizes (written H x W x C) of\n",
"        1250 x 12 x 64 after the stem, 416 x 12 x 149 after stack 1a,\n",
"        138 x 12 x 224 after 1b and 69 x 12 x 298 after 1c -- presumably\n",
"        for an (N, 1, 2500, 12) input; the tensors themselves are\n",
"        presumably channel-first (N, C, H, W) -- TODO confirm.\n",
"        \"\"\"\n",
"        features = self.conv_1(x)\n",
"        for stack in (\n",
"            self.multi_2d_cnn_1a,\n",
"            self.multi_2d_cnn_1b,\n",
"            self.multi_2d_cnn_1c,\n",
"            self.multi_2d_cnn_2a,\n",
"            self.multi_2d_cnn_2b,\n",
"            self.multi_2d_cnn_2c,\n",
"            self.multi_2d_cnn_2d,\n",
"        ):\n",
"            features = stack(features)\n",
"        return self.output(features)\n",
"\n",
"\n",
"class Inception_Tabular_15_1(nn.Module):\n",
"    \"\"\"Inception-style 2D CNN whose stem convolution uses a (15, 1) kernel.\n",
"\n",
"    Layout: a BasicConv2d stem, seven stacks of Multi_2D_CNN_block (all but\n",
"    the last closed by a MaxPool2d that pools the first spatial axis only),\n",
"    and a head that average-pools, flattens, applies dropout and projects\n",
"    to one raw output value (nn.Sigmoid stays disabled).\n",
"\n",
"    NOTE(review): every in_channels value below is hard-coded; they\n",
"    presumably only match the block widths produced for the default\n",
"    initial_kernel_num=64 -- confirm before passing another value.\n",
"\n",
"    Args:\n",
"        initial_kernel_num: base value that is scaled per stack and handed\n",
"            to Multi_2D_CNN_block as num_kernel.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, initial_kernel_num=64):\n",
"        super().__init__()\n",
"        unit = initial_kernel_num\n",
"\n",
"        def assemble(pool_len, *specs):\n",
"            # specs are (in_channels, num_kernel) pairs, instantiated in order;\n",
"            # pool_len=None means the stack has no trailing MaxPool2d.\n",
"            parts = [Multi_2D_CNN_block(in_channels=c, num_kernel=n) for c, n in specs]\n",
"            if pool_len is not None:\n",
"                parts.append(nn.MaxPool2d(kernel_size=(pool_len, 1)))\n",
"            return nn.Sequential(*parts)\n",
"\n",
"        self.conv_1 = BasicConv2d(1, 64, kernel_size=(15, 1), stride=(2, 1))\n",
"\n",
"        self.multi_2d_cnn_1a = assemble(3, (64, unit), (149, unit))\n",
"        self.multi_2d_cnn_1b = assemble(3, (149, unit * 1.5), (224, unit * 1.5))\n",
"        self.multi_2d_cnn_1c = assemble(2, (224, unit * 2), (298, unit * 2))\n",
"\n",
"        self.multi_2d_cnn_2a = assemble(2, (298, unit * 3), (448, unit * 3), (448, unit * 4))\n",
"        self.multi_2d_cnn_2b = assemble(2, (597, unit * 5), (746, unit * 6), (896, unit * 7))\n",
"        self.multi_2d_cnn_2c = assemble(2, (1045, unit * 8), (1194, unit * 8), (1194, unit * 8))\n",
"        self.multi_2d_cnn_2d = assemble(None, (1194, unit * 12), (1792, unit * 14), (2090, unit * 16))\n",
"\n",
"        self.output = nn.Sequential(\n",
"            nn.AdaptiveAvgPool2d((1, 1)),\n",
"            nn.Flatten(),\n",
"            nn.Dropout(0.5),\n",
"            nn.Linear(2389, 1),\n",
"            # nn.Sigmoid()\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Apply the stem, the seven block stacks in order, then the head.\n",
"\n",
"        The original author's notes give sizes (written H x W x C) of\n",
"        1250 x 12 x 64 after the stem, 416 x 12 x 149 after stack 1a,\n",
"        138 x 12 x 224 after 1b and 69 x 12 x 298 after 1c -- presumably\n",
"        for an (N, 1, 2500, 12) input; TODO confirm for this stem kernel.\n",
"        \"\"\"\n",
"        features = self.conv_1(x)\n",
"        for stack in (\n",
"            self.multi_2d_cnn_1a,\n",
"            self.multi_2d_cnn_1b,\n",
"            self.multi_2d_cnn_1c,\n",
"            self.multi_2d_cnn_2a,\n",
"            self.multi_2d_cnn_2b,\n",
"            self.multi_2d_cnn_2c,\n",
"            self.multi_2d_cnn_2d,\n",
"        ):\n",
"            features = stack(features)\n",
"        return self.output(features)\n",
"\n",
"class Inception_Tabular_5_Linears(nn.Module):\n",
"    \"\"\"Inception-style 2D CNN with a (7, 3) stem and a five-Linear head.\n",
"\n",
"    Layout: a BasicConv2d stem, seven stacks of Multi_2D_CNN_block (all but\n",
"    the last closed by a MaxPool2d that pools the first spatial axis only),\n",
"    and a head that average-pools, flattens, applies dropout and then runs\n",
"    Linear layers 2389 -> 1200 -> 400 -> 100 -> 20 -> 1. nn.Sigmoid stays\n",
"    disabled, so the result is the raw value of the last Linear.\n",
"\n",
"    NOTE(review): the five nn.Linear layers are stacked with nothing\n",
"    between them (no activation or normalisation) -- confirm this is\n",
"    intended.\n",
"    NOTE(review): every in_channels value below is hard-coded; they\n",
"    presumably only match the block widths produced for the default\n",
"    initial_kernel_num=64 -- confirm before passing another value.\n",
"\n",
"    Args:\n",
"        initial_kernel_num: base value that is scaled per stack and handed\n",
"            to Multi_2D_CNN_block as num_kernel.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, initial_kernel_num=64):\n",
"        super().__init__()\n",
"        base = initial_kernel_num\n",
"\n",
"        def build_stack(pool_len, *specs):\n",
"            # specs are (in_channels, num_kernel) pairs, instantiated in order;\n",
"            # pool_len=None means the stack has no trailing MaxPool2d.\n",
"            layers = [Multi_2D_CNN_block(in_channels=c, num_kernel=n) for c, n in specs]\n",
"            if pool_len is not None:\n",
"                layers.append(nn.MaxPool2d(kernel_size=(pool_len, 1)))\n",
"            return nn.Sequential(*layers)\n",
"\n",
"        self.conv_1 = BasicConv2d(1, 64, kernel_size=(7, 3), stride=(2, 1))\n",
"\n",
"        self.multi_2d_cnn_1a = build_stack(3, (64, base), (149, base))\n",
"        self.multi_2d_cnn_1b = build_stack(3, (149, base * 1.5), (224, base * 1.5))\n",
"        self.multi_2d_cnn_1c = build_stack(2, (224, base * 2), (298, base * 2))\n",
"\n",
"        self.multi_2d_cnn_2a = build_stack(2, (298, base * 3), (448, base * 3), (448, base * 4))\n",
"        self.multi_2d_cnn_2b = build_stack(2, (597, base * 5), (746, base * 6), (896, base * 7))\n",
"        self.multi_2d_cnn_2c = build_stack(2, (1045, base * 8), (1194, base * 8), (1194, base * 8))\n",
"        self.multi_2d_cnn_2d = build_stack(None, (1194, base * 12), (1792, base * 14), (2090, base * 16))\n",
"\n",
"        # Consecutive pairs of head_widths become the five Linear layers.\n",
"        head_widths = (2389, 1200, 400, 100, 20, 1)\n",
"        head = [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Dropout(0.5)]\n",
"        head.extend(nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(head_widths, head_widths[1:]))\n",
"        # nn.Sigmoid() is deliberately not appended (it was commented out).\n",
"        self.output = nn.Sequential(*head)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Apply the stem, the seven block stacks in order, then the head.\n",
"\n",
"        The original author's notes give sizes (written H x W x C) of\n",
"        1250 x 12 x 64 after the stem, 416 x 12 x 149 after stack 1a,\n",
"        138 x 12 x 224 after 1b and 69 x 12 x 298 after 1c -- presumably\n",
"        for an (N, 1, 2500, 12) input; TODO confirm.\n",
"        \"\"\"\n",
"        features = self.conv_1(x)\n",
"        for stack in (\n",
"            self.multi_2d_cnn_1a,\n",
"            self.multi_2d_cnn_1b,\n",
"            self.multi_2d_cnn_1c,\n",
"            self.multi_2d_cnn_2a,\n",
"            self.multi_2d_cnn_2b,\n",
"            self.multi_2d_cnn_2c,\n",
"            self.multi_2d_cnn_2d,\n",
"        ):\n",
"            features = stack(features)\n",
"        return self.output(features)\n",
"\n",
"class Inception_Tabular_5_Linears_BatchNorm(nn.Module):\n",
"\n",
" def __init__(self, initial_kernel_num=64):\n",
" super(Inception_Tabular_5_Linears, self).__init__()\n",
"\n",
" multi_2d_cnn = Multi_2D_CNN_block\n",
" conv_block = BasicConv2d\n",
"\n",
" self.conv_1 = conv_block(1, 64, kernel_size=(7, 3), stride=(2, 1))\n",
"\n",
" self.multi_2d_cnn_1a = nn.Sequential(\n",
" multi_2d_cnn(in_channels=64, num_kernel=initial_kernel_num),\n",
" multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num),\n",
" nn.MaxPool2d(kernel_size=(3, 1))\n",
" )\n",
"\n",
" self.multi_2d_cnn_1b = nn.Sequential(\n",
" multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num * 1.5),\n",
" multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 1.5),\n",
" nn.MaxPool2d(kernel_size=(3, 1))\n",
" )\n",
" self.multi_2d_cnn_1c = nn.Sequential(\n",
" multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 2),\n",
" multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 2),\n",
" nn.MaxPool2d(kernel_size=(2, 1))\n",
" )\n",
"\n",
" self.multi_2d_cnn_2a = nn.Sequential(\n",
" multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 3),\n",
" multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 3),\n",
" multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 4),\n",
" nn.MaxPool2d(kernel_size=(2, 1))\n",
" )\n",
" self.multi_2d_cnn_2b = nn.Sequential(\n",
" multi_2d_cnn(in_channels=597, num_kernel=initial_kernel_num * 5),\n",
" multi_2d_cnn(in_channels=746, num_kernel=initial_kernel_num * 6),\n",
" multi_2d_cnn(in_channels=896, num_kernel=initial_kernel_num * 7),\n",
" nn.MaxPool2d(kernel_size=(2, 1))\n",
" )\n",
" self.multi_2d_cnn_2c = nn.Sequential(\n",
" multi_2d_cnn(in_channels=1045, num_kernel=initial_kernel_num * 8),\n",
" multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n",
" multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n",
" nn.MaxPool2d(kernel_size=(2, 1))\n",
" )\n",
" self.multi_2d_cnn_2d = nn.Sequential(\n",
" multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 12),\n",
" multi_2d_cnn(in_channels=1792, num_kernel=initial_kernel_num * 14),\n",
" multi_2d_cnn(in_channels=2090, num_kernel=initial_kernel_num * 16),\n",
" )\n",
" self.output = nn.Sequential(\n",
" nn.AdaptiveAvgPool2d((1, 1)),\n",
" nn.Flatten(),\n",
" nn.Dropout(0.5),\n",
" nn.Linear(2389, 1200),\n",
" nn.BatchNorm1d(1200),\n",
" nn.Linear(1200, 400),\n",
" nn.BatchNorm1d(400),\n",
" nn.Linear(400, 100),\n",
" nn.BatchNorm1d(100),\n",
" nn.Linear(100, 20),\n",
" nn.BatchNorm1d(20),\n",
" nn.Linear(20, 1)\n",
" # nn.Sigmoid()\n",
" )\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Run a batch through the stem, the seven multi-branch stages and the head.\n",
"\n",
"        The stages are applied strictly in the order listed below, each one\n",
"        consuming the previous stage's result. Shapes in the comments are\n",
"        channels-first (N x C x H x W) as reported by torchsummary for\n",
"        input_size=(1, 2500, 12); a different input size changes H and W.\n",
"\n",
"        Args:\n",
"            x: input batch; conv_1 is built for a single input channel.\n",
"\n",
"        Returns:\n",
"            The result of self.output, whose last layer is nn.Linear(20, 1).\n",
"            No sigmoid is applied here (it is commented out in the head).\n",
"        \"\"\"\n",
"        pipeline = (\n",
"            self.conv_1,           # -> N x 64 x 1250 x 12\n",
"            self.multi_2d_cnn_1a,  # -> N x 149 x 416 x 12\n",
"            self.multi_2d_cnn_1b,  # -> N x 224 x 138 x 12\n",
"            self.multi_2d_cnn_1c,  # -> N x 298 x 69 x 12\n",
"            self.multi_2d_cnn_2a,  # -> N x 597 x 34 x 12\n",
"            self.multi_2d_cnn_2b,  # -> N x 1045 x 17 x 12\n",
"            self.multi_2d_cnn_2c,  # -> N x 1194 x 8 x 12\n",
"            self.multi_2d_cnn_2d,  # -> N x 2389 x 8 x 12\n",
"            self.output,           # -> N x 1\n",
"        )\n",
"        for stage in pipeline:\n",
"            x = stage(x)\n",
"        return x\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Instantiate the network; no constructor arguments are passed here.\n",
"model = Inception_Tabular_15_3()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" ZeroPad2d-1 [-1, 1, 2514, 14] 0\n",
" Conv2d-2 [-1, 64, 1250, 12] 2,880\n",
" BatchNorm2d-3 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-4 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-5 [-1, 64, 1250, 12] 0\n",
" Conv2d-6 [-1, 21, 1250, 12] 1,344\n",
" BatchNorm2d-7 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-8 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-9 [-1, 64, 1250, 12] 0\n",
" Conv2d-10 [-1, 32, 1250, 12] 2,048\n",
" BatchNorm2d-11 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-12 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-13 [-1, 32, 1252, 14] 0\n",
" Conv2d-14 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-15 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-16 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-17 [-1, 64, 1250, 12] 0\n",
" Conv2d-18 [-1, 21, 1250, 12] 1,344\n",
" BatchNorm2d-19 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-20 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-21 [-1, 21, 1252, 14] 0\n",
" Conv2d-22 [-1, 32, 1250, 12] 6,048\n",
" BatchNorm2d-23 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-24 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-25 [-1, 32, 1252, 14] 0\n",
" Conv2d-26 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-27 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-28 [-1, 64, 1250, 12] 0\n",
" BatchNorm2d-29 [-1, 149, 1250, 12] 298\n",
"Multi_2D_CNN_block-30 [-1, 149, 1250, 12] 0\n",
" ZeroPad2d-31 [-1, 149, 1250, 12] 0\n",
" Conv2d-32 [-1, 21, 1250, 12] 3,129\n",
" BatchNorm2d-33 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-34 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-35 [-1, 149, 1250, 12] 0\n",
" Conv2d-36 [-1, 32, 1250, 12] 4,768\n",
" BatchNorm2d-37 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-38 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-39 [-1, 32, 1252, 14] 0\n",
" Conv2d-40 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-41 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-42 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-43 [-1, 149, 1250, 12] 0\n",
" Conv2d-44 [-1, 21, 1250, 12] 3,129\n",
" BatchNorm2d-45 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-46 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-47 [-1, 21, 1252, 14] 0\n",
" Conv2d-48 [-1, 32, 1250, 12] 6,048\n",
" BatchNorm2d-49 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-50 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-51 [-1, 32, 1252, 14] 0\n",
" Conv2d-52 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-53 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-54 [-1, 64, 1250, 12] 0\n",
" BatchNorm2d-55 [-1, 149, 1250, 12] 298\n",
"Multi_2D_CNN_block-56 [-1, 149, 1250, 12] 0\n",
" MaxPool2d-57 [-1, 149, 416, 12] 0\n",
" ZeroPad2d-58 [-1, 149, 416, 12] 0\n",
" Conv2d-59 [-1, 32, 416, 12] 4,768\n",
" BatchNorm2d-60 [-1, 32, 416, 12] 64\n",
" BasicConv2d-61 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-62 [-1, 149, 416, 12] 0\n",
" Conv2d-63 [-1, 48, 416, 12] 7,152\n",
" BatchNorm2d-64 [-1, 48, 416, 12] 96\n",
" BasicConv2d-65 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-66 [-1, 48, 418, 14] 0\n",
" Conv2d-67 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-68 [-1, 96, 416, 12] 192\n",
" BasicConv2d-69 [-1, 96, 416, 12] 0\n",
" ZeroPad2d-70 [-1, 149, 416, 12] 0\n",
" Conv2d-71 [-1, 32, 416, 12] 4,768\n",
" BatchNorm2d-72 [-1, 32, 416, 12] 64\n",
" BasicConv2d-73 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-74 [-1, 32, 418, 14] 0\n",
" Conv2d-75 [-1, 48, 416, 12] 13,824\n",
" BatchNorm2d-76 [-1, 48, 416, 12] 96\n",
" BasicConv2d-77 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-78 [-1, 48, 418, 14] 0\n",
" Conv2d-79 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-80 [-1, 96, 416, 12] 192\n",
" BasicConv2d-81 [-1, 96, 416, 12] 0\n",
" BatchNorm2d-82 [-1, 224, 416, 12] 448\n",
"Multi_2D_CNN_block-83 [-1, 224, 416, 12] 0\n",
" ZeroPad2d-84 [-1, 224, 416, 12] 0\n",
" Conv2d-85 [-1, 32, 416, 12] 7,168\n",
" BatchNorm2d-86 [-1, 32, 416, 12] 64\n",
" BasicConv2d-87 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-88 [-1, 224, 416, 12] 0\n",
" Conv2d-89 [-1, 48, 416, 12] 10,752\n",
" BatchNorm2d-90 [-1, 48, 416, 12] 96\n",
" BasicConv2d-91 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-92 [-1, 48, 418, 14] 0\n",
" Conv2d-93 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-94 [-1, 96, 416, 12] 192\n",
" BasicConv2d-95 [-1, 96, 416, 12] 0\n",
" ZeroPad2d-96 [-1, 224, 416, 12] 0\n",
" Conv2d-97 [-1, 32, 416, 12] 7,168\n",
" BatchNorm2d-98 [-1, 32, 416, 12] 64\n",
" BasicConv2d-99 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-100 [-1, 32, 418, 14] 0\n",
" Conv2d-101 [-1, 48, 416, 12] 13,824\n",
" BatchNorm2d-102 [-1, 48, 416, 12] 96\n",
" BasicConv2d-103 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-104 [-1, 48, 418, 14] 0\n",
" Conv2d-105 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-106 [-1, 96, 416, 12] 192\n",
" BasicConv2d-107 [-1, 96, 416, 12] 0\n",
" BatchNorm2d-108 [-1, 224, 416, 12] 448\n",
"Multi_2D_CNN_block-109 [-1, 224, 416, 12] 0\n",
" MaxPool2d-110 [-1, 224, 138, 12] 0\n",
" ZeroPad2d-111 [-1, 224, 138, 12] 0\n",
" Conv2d-112 [-1, 42, 138, 12] 9,408\n",
" BatchNorm2d-113 [-1, 42, 138, 12] 84\n",
" BasicConv2d-114 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-115 [-1, 224, 138, 12] 0\n",
" Conv2d-116 [-1, 64, 138, 12] 14,336\n",
" BatchNorm2d-117 [-1, 64, 138, 12] 128\n",
" BasicConv2d-118 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-119 [-1, 64, 140, 14] 0\n",
" Conv2d-120 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-121 [-1, 128, 138, 12] 256\n",
" BasicConv2d-122 [-1, 128, 138, 12] 0\n",
" ZeroPad2d-123 [-1, 224, 138, 12] 0\n",
" Conv2d-124 [-1, 42, 138, 12] 9,408\n",
" BatchNorm2d-125 [-1, 42, 138, 12] 84\n",
" BasicConv2d-126 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-127 [-1, 42, 140, 14] 0\n",
" Conv2d-128 [-1, 64, 138, 12] 24,192\n",
" BatchNorm2d-129 [-1, 64, 138, 12] 128\n",
" BasicConv2d-130 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-131 [-1, 64, 140, 14] 0\n",
" Conv2d-132 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-133 [-1, 128, 138, 12] 256\n",
" BasicConv2d-134 [-1, 128, 138, 12] 0\n",
" BatchNorm2d-135 [-1, 298, 138, 12] 596\n",
"Multi_2D_CNN_block-136 [-1, 298, 138, 12] 0\n",
" ZeroPad2d-137 [-1, 298, 138, 12] 0\n",
" Conv2d-138 [-1, 42, 138, 12] 12,516\n",
" BatchNorm2d-139 [-1, 42, 138, 12] 84\n",
" BasicConv2d-140 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-141 [-1, 298, 138, 12] 0\n",
" Conv2d-142 [-1, 64, 138, 12] 19,072\n",
" BatchNorm2d-143 [-1, 64, 138, 12] 128\n",
" BasicConv2d-144 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-145 [-1, 64, 140, 14] 0\n",
" Conv2d-146 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-147 [-1, 128, 138, 12] 256\n",
" BasicConv2d-148 [-1, 128, 138, 12] 0\n",
" ZeroPad2d-149 [-1, 298, 138, 12] 0\n",
" Conv2d-150 [-1, 42, 138, 12] 12,516\n",
" BatchNorm2d-151 [-1, 42, 138, 12] 84\n",
" BasicConv2d-152 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-153 [-1, 42, 140, 14] 0\n",
" Conv2d-154 [-1, 64, 138, 12] 24,192\n",
" BatchNorm2d-155 [-1, 64, 138, 12] 128\n",
" BasicConv2d-156 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-157 [-1, 64, 140, 14] 0\n",
" Conv2d-158 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-159 [-1, 128, 138, 12] 256\n",
" BasicConv2d-160 [-1, 128, 138, 12] 0\n",
" BatchNorm2d-161 [-1, 298, 138, 12] 596\n",
"Multi_2D_CNN_block-162 [-1, 298, 138, 12] 0\n",
" MaxPool2d-163 [-1, 298, 69, 12] 0\n",
" ZeroPad2d-164 [-1, 298, 69, 12] 0\n",
" Conv2d-165 [-1, 64, 69, 12] 19,072\n",
" BatchNorm2d-166 [-1, 64, 69, 12] 128\n",
" BasicConv2d-167 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-168 [-1, 298, 69, 12] 0\n",
" Conv2d-169 [-1, 96, 69, 12] 28,608\n",
" BatchNorm2d-170 [-1, 96, 69, 12] 192\n",
" BasicConv2d-171 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-172 [-1, 96, 71, 14] 0\n",
" Conv2d-173 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-174 [-1, 192, 69, 12] 384\n",
" BasicConv2d-175 [-1, 192, 69, 12] 0\n",
" ZeroPad2d-176 [-1, 298, 69, 12] 0\n",
" Conv2d-177 [-1, 64, 69, 12] 19,072\n",
" BatchNorm2d-178 [-1, 64, 69, 12] 128\n",
" BasicConv2d-179 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-180 [-1, 64, 71, 14] 0\n",
" Conv2d-181 [-1, 96, 69, 12] 55,296\n",
" BatchNorm2d-182 [-1, 96, 69, 12] 192\n",
" BasicConv2d-183 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-184 [-1, 96, 71, 14] 0\n",
" Conv2d-185 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-186 [-1, 192, 69, 12] 384\n",
" BasicConv2d-187 [-1, 192, 69, 12] 0\n",
" BatchNorm2d-188 [-1, 448, 69, 12] 896\n",
"Multi_2D_CNN_block-189 [-1, 448, 69, 12] 0\n",
" ZeroPad2d-190 [-1, 448, 69, 12] 0\n",
" Conv2d-191 [-1, 64, 69, 12] 28,672\n",
" BatchNorm2d-192 [-1, 64, 69, 12] 128\n",
" BasicConv2d-193 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-194 [-1, 448, 69, 12] 0\n",
" Conv2d-195 [-1, 96, 69, 12] 43,008\n",
" BatchNorm2d-196 [-1, 96, 69, 12] 192\n",
" BasicConv2d-197 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-198 [-1, 96, 71, 14] 0\n",
" Conv2d-199 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-200 [-1, 192, 69, 12] 384\n",
" BasicConv2d-201 [-1, 192, 69, 12] 0\n",
" ZeroPad2d-202 [-1, 448, 69, 12] 0\n",
" Conv2d-203 [-1, 64, 69, 12] 28,672\n",
" BatchNorm2d-204 [-1, 64, 69, 12] 128\n",
" BasicConv2d-205 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-206 [-1, 64, 71, 14] 0\n",
" Conv2d-207 [-1, 96, 69, 12] 55,296\n",
" BatchNorm2d-208 [-1, 96, 69, 12] 192\n",
" BasicConv2d-209 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-210 [-1, 96, 71, 14] 0\n",
" Conv2d-211 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-212 [-1, 192, 69, 12] 384\n",
" BasicConv2d-213 [-1, 192, 69, 12] 0\n",
" BatchNorm2d-214 [-1, 448, 69, 12] 896\n",
"Multi_2D_CNN_block-215 [-1, 448, 69, 12] 0\n",
" ZeroPad2d-216 [-1, 448, 69, 12] 0\n",
" Conv2d-217 [-1, 85, 69, 12] 38,080\n",
" BatchNorm2d-218 [-1, 85, 69, 12] 170\n",
" BasicConv2d-219 [-1, 85, 69, 12] 0\n",
" ZeroPad2d-220 [-1, 448, 69, 12] 0\n",
" Conv2d-221 [-1, 128, 69, 12] 57,344\n",
" BatchNorm2d-222 [-1, 128, 69, 12] 256\n",
" BasicConv2d-223 [-1, 128, 69, 12] 0\n",
" ZeroPad2d-224 [-1, 128, 71, 14] 0\n",
" Conv2d-225 [-1, 256, 69, 12] 294,912\n",
" BatchNorm2d-226 [-1, 256, 69, 12] 512\n",
" BasicConv2d-227 [-1, 256, 69, 12] 0\n",
" ZeroPad2d-228 [-1, 448, 69, 12] 0\n",
" Conv2d-229 [-1, 85, 69, 12] 38,080\n",
" BatchNorm2d-230 [-1, 85, 69, 12] 170\n",
" BasicConv2d-231 [-1, 85, 69, 12] 0\n",
" ZeroPad2d-232 [-1, 85, 71, 14] 0\n",
" Conv2d-233 [-1, 128, 69, 12] 97,920\n",
" BatchNorm2d-234 [-1, 128, 69, 12] 256\n",
" BasicConv2d-235 [-1, 128, 69, 12] 0\n",
" ZeroPad2d-236 [-1, 128, 71, 14] 0\n",
" Conv2d-237 [-1, 256, 69, 12] 294,912\n",
" BatchNorm2d-238 [-1, 256, 69, 12] 512\n",
" BasicConv2d-239 [-1, 256, 69, 12] 0\n",
" BatchNorm2d-240 [-1, 597, 69, 12] 1,194\n",
"Multi_2D_CNN_block-241 [-1, 597, 69, 12] 0\n",
" MaxPool2d-242 [-1, 597, 34, 12] 0\n",
" ZeroPad2d-243 [-1, 597, 34, 12] 0\n",
" Conv2d-244 [-1, 106, 34, 12] 63,282\n",
" BatchNorm2d-245 [-1, 106, 34, 12] 212\n",
" BasicConv2d-246 [-1, 106, 34, 12] 0\n",
" ZeroPad2d-247 [-1, 597, 34, 12] 0\n",
" Conv2d-248 [-1, 160, 34, 12] 95,520\n",
" BatchNorm2d-249 [-1, 160, 34, 12] 320\n",
" BasicConv2d-250 [-1, 160, 34, 12] 0\n",
" ZeroPad2d-251 [-1, 160, 36, 14] 0\n",
" Conv2d-252 [-1, 320, 34, 12] 460,800\n",
" BatchNorm2d-253 [-1, 320, 34, 12] 640\n",
" BasicConv2d-254 [-1, 320, 34, 12] 0\n",
" ZeroPad2d-255 [-1, 597, 34, 12] 0\n",
" Conv2d-256 [-1, 106, 34, 12] 63,282\n",
" BatchNorm2d-257 [-1, 106, 34, 12] 212\n",
" BasicConv2d-258 [-1, 106, 34, 12] 0\n",
" ZeroPad2d-259 [-1, 106, 36, 14] 0\n",
" Conv2d-260 [-1, 160, 34, 12] 152,640\n",
" BatchNorm2d-261 [-1, 160, 34, 12] 320\n",
" BasicConv2d-262 [-1, 160, 34, 12] 0\n",
" ZeroPad2d-263 [-1, 160, 36, 14] 0\n",
" Conv2d-264 [-1, 320, 34, 12] 460,800\n",
" BatchNorm2d-265 [-1, 320, 34, 12] 640\n",
" BasicConv2d-266 [-1, 320, 34, 12] 0\n",
" BatchNorm2d-267 [-1, 746, 34, 12] 1,492\n",
"Multi_2D_CNN_block-268 [-1, 746, 34, 12] 0\n",
" ZeroPad2d-269 [-1, 746, 34, 12] 0\n",
" Conv2d-270 [-1, 128, 34, 12] 95,488\n",
" BatchNorm2d-271 [-1, 128, 34, 12] 256\n",
" BasicConv2d-272 [-1, 128, 34, 12] 0\n",
" ZeroPad2d-273 [-1, 746, 34, 12] 0\n",
" Conv2d-274 [-1, 192, 34, 12] 143,232\n",
" BatchNorm2d-275 [-1, 192, 34, 12] 384\n",
" BasicConv2d-276 [-1, 192, 34, 12] 0\n",
" ZeroPad2d-277 [-1, 192, 36, 14] 0\n",
" Conv2d-278 [-1, 384, 34, 12] 663,552\n",
" BatchNorm2d-279 [-1, 384, 34, 12] 768\n",
" BasicConv2d-280 [-1, 384, 34, 12] 0\n",
" ZeroPad2d-281 [-1, 746, 34, 12] 0\n",
" Conv2d-282 [-1, 128, 34, 12] 95,488\n",
" BatchNorm2d-283 [-1, 128, 34, 12] 256\n",
" BasicConv2d-284 [-1, 128, 34, 12] 0\n",
" ZeroPad2d-285 [-1, 128, 36, 14] 0\n",
" Conv2d-286 [-1, 192, 34, 12] 221,184\n",
" BatchNorm2d-287 [-1, 192, 34, 12] 384\n",
" BasicConv2d-288 [-1, 192, 34, 12] 0\n",
" ZeroPad2d-289 [-1, 192, 36, 14] 0\n",
" Conv2d-290 [-1, 384, 34, 12] 663,552\n",
" BatchNorm2d-291 [-1, 384, 34, 12] 768\n",
" BasicConv2d-292 [-1, 384, 34, 12] 0\n",
" BatchNorm2d-293 [-1, 896, 34, 12] 1,792\n",
"Multi_2D_CNN_block-294 [-1, 896, 34, 12] 0\n",
" ZeroPad2d-295 [-1, 896, 34, 12] 0\n",
" Conv2d-296 [-1, 149, 34, 12] 133,504\n",
" BatchNorm2d-297 [-1, 149, 34, 12] 298\n",
" BasicConv2d-298 [-1, 149, 34, 12] 0\n",
" ZeroPad2d-299 [-1, 896, 34, 12] 0\n",
" Conv2d-300 [-1, 224, 34, 12] 200,704\n",
" BatchNorm2d-301 [-1, 224, 34, 12] 448\n",
" BasicConv2d-302 [-1, 224, 34, 12] 0\n",
" ZeroPad2d-303 [-1, 224, 36, 14] 0\n",
" Conv2d-304 [-1, 448, 34, 12] 903,168\n",
" BatchNorm2d-305 [-1, 448, 34, 12] 896\n",
" BasicConv2d-306 [-1, 448, 34, 12] 0\n",
" ZeroPad2d-307 [-1, 896, 34, 12] 0\n",
" Conv2d-308 [-1, 149, 34, 12] 133,504\n",
" BatchNorm2d-309 [-1, 149, 34, 12] 298\n",
" BasicConv2d-310 [-1, 149, 34, 12] 0\n",
" ZeroPad2d-311 [-1, 149, 36, 14] 0\n",
" Conv2d-312 [-1, 224, 34, 12] 300,384\n",
" BatchNorm2d-313 [-1, 224, 34, 12] 448\n",
" BasicConv2d-314 [-1, 224, 34, 12] 0\n",
" ZeroPad2d-315 [-1, 224, 36, 14] 0\n",
" Conv2d-316 [-1, 448, 34, 12] 903,168\n",
" BatchNorm2d-317 [-1, 448, 34, 12] 896\n",
" BasicConv2d-318 [-1, 448, 34, 12] 0\n",
" BatchNorm2d-319 [-1, 1045, 34, 12] 2,090\n",
"Multi_2D_CNN_block-320 [-1, 1045, 34, 12] 0\n",
" MaxPool2d-321 [-1, 1045, 17, 12] 0\n",
" ZeroPad2d-322 [-1, 1045, 17, 12] 0\n",
" Conv2d-323 [-1, 170, 17, 12] 177,650\n",
" BatchNorm2d-324 [-1, 170, 17, 12] 340\n",
" BasicConv2d-325 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-326 [-1, 1045, 17, 12] 0\n",
" Conv2d-327 [-1, 256, 17, 12] 267,520\n",
" BatchNorm2d-328 [-1, 256, 17, 12] 512\n",
" BasicConv2d-329 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-330 [-1, 256, 19, 14] 0\n",
" Conv2d-331 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-332 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-333 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-334 [-1, 1045, 17, 12] 0\n",
" Conv2d-335 [-1, 170, 17, 12] 177,650\n",
" BatchNorm2d-336 [-1, 170, 17, 12] 340\n",
" BasicConv2d-337 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-338 [-1, 170, 19, 14] 0\n",
" Conv2d-339 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-340 [-1, 256, 17, 12] 512\n",
" BasicConv2d-341 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-342 [-1, 256, 19, 14] 0\n",
" Conv2d-343 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-344 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-345 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-346 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-347 [-1, 1194, 17, 12] 0\n",
" ZeroPad2d-348 [-1, 1194, 17, 12] 0\n",
" Conv2d-349 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-350 [-1, 170, 17, 12] 340\n",
" BasicConv2d-351 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-352 [-1, 1194, 17, 12] 0\n",
" Conv2d-353 [-1, 256, 17, 12] 305,664\n",
" BatchNorm2d-354 [-1, 256, 17, 12] 512\n",
" BasicConv2d-355 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-356 [-1, 256, 19, 14] 0\n",
" Conv2d-357 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-358 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-359 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-360 [-1, 1194, 17, 12] 0\n",
" Conv2d-361 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-362 [-1, 170, 17, 12] 340\n",
" BasicConv2d-363 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-364 [-1, 170, 19, 14] 0\n",
" Conv2d-365 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-366 [-1, 256, 17, 12] 512\n",
" BasicConv2d-367 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-368 [-1, 256, 19, 14] 0\n",
" Conv2d-369 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-370 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-371 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-372 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-373 [-1, 1194, 17, 12] 0\n",
" ZeroPad2d-374 [-1, 1194, 17, 12] 0\n",
" Conv2d-375 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-376 [-1, 170, 17, 12] 340\n",
" BasicConv2d-377 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-378 [-1, 1194, 17, 12] 0\n",
" Conv2d-379 [-1, 256, 17, 12] 305,664\n",
" BatchNorm2d-380 [-1, 256, 17, 12] 512\n",
" BasicConv2d-381 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-382 [-1, 256, 19, 14] 0\n",
" Conv2d-383 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-384 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-385 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-386 [-1, 1194, 17, 12] 0\n",
" Conv2d-387 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-388 [-1, 170, 17, 12] 340\n",
" BasicConv2d-389 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-390 [-1, 170, 19, 14] 0\n",
" Conv2d-391 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-392 [-1, 256, 17, 12] 512\n",
" BasicConv2d-393 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-394 [-1, 256, 19, 14] 0\n",
" Conv2d-395 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-396 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-397 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-398 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-399 [-1, 1194, 17, 12] 0\n",
" MaxPool2d-400 [-1, 1194, 8, 12] 0\n",
" ZeroPad2d-401 [-1, 1194, 8, 12] 0\n",
" Conv2d-402 [-1, 256, 8, 12] 305,664\n",
" BatchNorm2d-403 [-1, 256, 8, 12] 512\n",
" BasicConv2d-404 [-1, 256, 8, 12] 0\n",
" ZeroPad2d-405 [-1, 1194, 8, 12] 0\n",
" Conv2d-406 [-1, 384, 8, 12] 458,496\n",
" BatchNorm2d-407 [-1, 384, 8, 12] 768\n",
" BasicConv2d-408 [-1, 384, 8, 12] 0\n",
" ZeroPad2d-409 [-1, 384, 10, 14] 0\n",
" Conv2d-410 [-1, 768, 8, 12] 2,654,208\n",
" BatchNorm2d-411 [-1, 768, 8, 12] 1,536\n",
" BasicConv2d-412 [-1, 768, 8, 12] 0\n",
" ZeroPad2d-413 [-1, 1194, 8, 12] 0\n",
" Conv2d-414 [-1, 256, 8, 12] 305,664\n",
" BatchNorm2d-415 [-1, 256, 8, 12] 512\n",
" BasicConv2d-416 [-1, 256, 8, 12] 0\n",
" ZeroPad2d-417 [-1, 256, 10, 14] 0\n",
" Conv2d-418 [-1, 384, 8, 12] 884,736\n",
" BatchNorm2d-419 [-1, 384, 8, 12] 768\n",
" BasicConv2d-420 [-1, 384, 8, 12] 0\n",
" ZeroPad2d-421 [-1, 384, 10, 14] 0\n",
" Conv2d-422 [-1, 768, 8, 12] 2,654,208\n",
" BatchNorm2d-423 [-1, 768, 8, 12] 1,536\n",
" BasicConv2d-424 [-1, 768, 8, 12] 0\n",
" BatchNorm2d-425 [-1, 1792, 8, 12] 3,584\n",
"Multi_2D_CNN_block-426 [-1, 1792, 8, 12] 0\n",
" ZeroPad2d-427 [-1, 1792, 8, 12] 0\n",
" Conv2d-428 [-1, 298, 8, 12] 534,016\n",
" BatchNorm2d-429 [-1, 298, 8, 12] 596\n",
" BasicConv2d-430 [-1, 298, 8, 12] 0\n",
" ZeroPad2d-431 [-1, 1792, 8, 12] 0\n",
" Conv2d-432 [-1, 448, 8, 12] 802,816\n",
" BatchNorm2d-433 [-1, 448, 8, 12] 896\n",
" BasicConv2d-434 [-1, 448, 8, 12] 0\n",
" ZeroPad2d-435 [-1, 448, 10, 14] 0\n",
" Conv2d-436 [-1, 896, 8, 12] 3,612,672\n",
" BatchNorm2d-437 [-1, 896, 8, 12] 1,792\n",
" BasicConv2d-438 [-1, 896, 8, 12] 0\n",
" ZeroPad2d-439 [-1, 1792, 8, 12] 0\n",
" Conv2d-440 [-1, 298, 8, 12] 534,016\n",
" BatchNorm2d-441 [-1, 298, 8, 12] 596\n",
" BasicConv2d-442 [-1, 298, 8, 12] 0\n",
" ZeroPad2d-443 [-1, 298, 10, 14] 0\n",
" Conv2d-444 [-1, 448, 8, 12] 1,201,536\n",
" BatchNorm2d-445 [-1, 448, 8, 12] 896\n",
" BasicConv2d-446 [-1, 448, 8, 12] 0\n",
" ZeroPad2d-447 [-1, 448, 10, 14] 0\n",
" Conv2d-448 [-1, 896, 8, 12] 3,612,672\n",
" BatchNorm2d-449 [-1, 896, 8, 12] 1,792\n",
" BasicConv2d-450 [-1, 896, 8, 12] 0\n",
" BatchNorm2d-451 [-1, 2090, 8, 12] 4,180\n",
"Multi_2D_CNN_block-452 [-1, 2090, 8, 12] 0\n",
" ZeroPad2d-453 [-1, 2090, 8, 12] 0\n",
" Conv2d-454 [-1, 341, 8, 12] 712,690\n",
" BatchNorm2d-455 [-1, 341, 8, 12] 682\n",
" BasicConv2d-456 [-1, 341, 8, 12] 0\n",
" ZeroPad2d-457 [-1, 2090, 8, 12] 0\n",
" Conv2d-458 [-1, 512, 8, 12] 1,070,080\n",
" BatchNorm2d-459 [-1, 512, 8, 12] 1,024\n",
" BasicConv2d-460 [-1, 512, 8, 12] 0\n",
" ZeroPad2d-461 [-1, 512, 10, 14] 0\n",
" Conv2d-462 [-1, 1024, 8, 12] 4,718,592\n",
" BatchNorm2d-463 [-1, 1024, 8, 12] 2,048\n",
" BasicConv2d-464 [-1, 1024, 8, 12] 0\n",
" ZeroPad2d-465 [-1, 2090, 8, 12] 0\n",
" Conv2d-466 [-1, 341, 8, 12] 712,690\n",
" BatchNorm2d-467 [-1, 341, 8, 12] 682\n",
" BasicConv2d-468 [-1, 341, 8, 12] 0\n",
" ZeroPad2d-469 [-1, 341, 10, 14] 0\n",
" Conv2d-470 [-1, 512, 8, 12] 1,571,328\n",
" BatchNorm2d-471 [-1, 512, 8, 12] 1,024\n",
" BasicConv2d-472 [-1, 512, 8, 12] 0\n",
" ZeroPad2d-473 [-1, 512, 10, 14] 0\n",
" Conv2d-474 [-1, 1024, 8, 12] 4,718,592\n",
" BatchNorm2d-475 [-1, 1024, 8, 12] 2,048\n",
" BasicConv2d-476 [-1, 1024, 8, 12] 0\n",
" BatchNorm2d-477 [-1, 2389, 8, 12] 4,778\n",
"Multi_2D_CNN_block-478 [-1, 2389, 8, 12] 0\n",
"AdaptiveAvgPool2d-479 [-1, 2389, 1, 1] 0\n",
" Flatten-480 [-1, 2389] 0\n",
" Dropout-481 [-1, 2389] 0\n",
" Linear-482 [-1, 1] 2,390\n",
"================================================================\n",
"Total params: 49,721,334\n",
"Trainable params: 49,721,334\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.11\n",
"Forward/backward pass size (MB): 884.62\n",
"Params size (MB): 189.67\n",
"Estimated Total Size (MB): 1074.40\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"# Layer-by-layer report for a freshly built instance (not `model`) on a 1 x 2500 x 12 input.\n",
"# NOTE(review): the output saved under this cell looks stale - it lists a 2,880-parameter\n",
"# first convolution and a single Linear head (2389 -> 1), whereas the class above builds\n",
"# conv_1 with kernel_size=(7, 3) and a five-layer Linear/BatchNorm1d head.\n",
"# Re-run this cell after any change to the class so the report matches the code.\n",
"summary(Inception_Tabular_15_3(), input_size=(1, 2500, 12))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" ZeroPad2d-1 [-1, 1, 2506, 14] 0\n",
" Conv2d-2 [-1, 64, 1250, 12] 1,344\n",
" BatchNorm2d-3 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-4 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-5 [-1, 64, 1250, 12] 0\n",
" Conv2d-6 [-1, 21, 1250, 12] 1,344\n",
" BatchNorm2d-7 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-8 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-9 [-1, 64, 1250, 12] 0\n",
" Conv2d-10 [-1, 32, 1250, 12] 2,048\n",
" BatchNorm2d-11 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-12 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-13 [-1, 32, 1252, 14] 0\n",
" Conv2d-14 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-15 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-16 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-17 [-1, 64, 1250, 12] 0\n",
" Conv2d-18 [-1, 21, 1250, 12] 1,344\n",
" BatchNorm2d-19 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-20 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-21 [-1, 21, 1252, 14] 0\n",
" Conv2d-22 [-1, 32, 1250, 12] 6,048\n",
" BatchNorm2d-23 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-24 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-25 [-1, 32, 1252, 14] 0\n",
" Conv2d-26 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-27 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-28 [-1, 64, 1250, 12] 0\n",
" BatchNorm2d-29 [-1, 149, 1250, 12] 298\n",
"Multi_2D_CNN_block-30 [-1, 149, 1250, 12] 0\n",
" ZeroPad2d-31 [-1, 149, 1250, 12] 0\n",
" Conv2d-32 [-1, 21, 1250, 12] 3,129\n",
" BatchNorm2d-33 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-34 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-35 [-1, 149, 1250, 12] 0\n",
" Conv2d-36 [-1, 32, 1250, 12] 4,768\n",
" BatchNorm2d-37 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-38 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-39 [-1, 32, 1252, 14] 0\n",
" Conv2d-40 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-41 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-42 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-43 [-1, 149, 1250, 12] 0\n",
" Conv2d-44 [-1, 21, 1250, 12] 3,129\n",
" BatchNorm2d-45 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-46 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-47 [-1, 21, 1252, 14] 0\n",
" Conv2d-48 [-1, 32, 1250, 12] 6,048\n",
" BatchNorm2d-49 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-50 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-51 [-1, 32, 1252, 14] 0\n",
" Conv2d-52 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-53 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-54 [-1, 64, 1250, 12] 0\n",
" BatchNorm2d-55 [-1, 149, 1250, 12] 298\n",
"Multi_2D_CNN_block-56 [-1, 149, 1250, 12] 0\n",
" MaxPool2d-57 [-1, 149, 416, 12] 0\n",
" ZeroPad2d-58 [-1, 149, 416, 12] 0\n",
" Conv2d-59 [-1, 32, 416, 12] 4,768\n",
" BatchNorm2d-60 [-1, 32, 416, 12] 64\n",
" BasicConv2d-61 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-62 [-1, 149, 416, 12] 0\n",
" Conv2d-63 [-1, 48, 416, 12] 7,152\n",
" BatchNorm2d-64 [-1, 48, 416, 12] 96\n",
" BasicConv2d-65 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-66 [-1, 48, 418, 14] 0\n",
" Conv2d-67 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-68 [-1, 96, 416, 12] 192\n",
" BasicConv2d-69 [-1, 96, 416, 12] 0\n",
" ZeroPad2d-70 [-1, 149, 416, 12] 0\n",
" Conv2d-71 [-1, 32, 416, 12] 4,768\n",
" BatchNorm2d-72 [-1, 32, 416, 12] 64\n",
" BasicConv2d-73 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-74 [-1, 32, 418, 14] 0\n",
" Conv2d-75 [-1, 48, 416, 12] 13,824\n",
" BatchNorm2d-76 [-1, 48, 416, 12] 96\n",
" BasicConv2d-77 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-78 [-1, 48, 418, 14] 0\n",
" Conv2d-79 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-80 [-1, 96, 416, 12] 192\n",
" BasicConv2d-81 [-1, 96, 416, 12] 0\n",
" BatchNorm2d-82 [-1, 224, 416, 12] 448\n",
"Multi_2D_CNN_block-83 [-1, 224, 416, 12] 0\n",
" ZeroPad2d-84 [-1, 224, 416, 12] 0\n",
" Conv2d-85 [-1, 32, 416, 12] 7,168\n",
" BatchNorm2d-86 [-1, 32, 416, 12] 64\n",
" BasicConv2d-87 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-88 [-1, 224, 416, 12] 0\n",
" Conv2d-89 [-1, 48, 416, 12] 10,752\n",
" BatchNorm2d-90 [-1, 48, 416, 12] 96\n",
" BasicConv2d-91 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-92 [-1, 48, 418, 14] 0\n",
" Conv2d-93 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-94 [-1, 96, 416, 12] 192\n",
" BasicConv2d-95 [-1, 96, 416, 12] 0\n",
" ZeroPad2d-96 [-1, 224, 416, 12] 0\n",
" Conv2d-97 [-1, 32, 416, 12] 7,168\n",
" BatchNorm2d-98 [-1, 32, 416, 12] 64\n",
" BasicConv2d-99 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-100 [-1, 32, 418, 14] 0\n",
" Conv2d-101 [-1, 48, 416, 12] 13,824\n",
" BatchNorm2d-102 [-1, 48, 416, 12] 96\n",
" BasicConv2d-103 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-104 [-1, 48, 418, 14] 0\n",
" Conv2d-105 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-106 [-1, 96, 416, 12] 192\n",
" BasicConv2d-107 [-1, 96, 416, 12] 0\n",
" BatchNorm2d-108 [-1, 224, 416, 12] 448\n",
"Multi_2D_CNN_block-109 [-1, 224, 416, 12] 0\n",
" MaxPool2d-110 [-1, 224, 138, 12] 0\n",
" ZeroPad2d-111 [-1, 224, 138, 12] 0\n",
" Conv2d-112 [-1, 42, 138, 12] 9,408\n",
" BatchNorm2d-113 [-1, 42, 138, 12] 84\n",
" BasicConv2d-114 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-115 [-1, 224, 138, 12] 0\n",
" Conv2d-116 [-1, 64, 138, 12] 14,336\n",
" BatchNorm2d-117 [-1, 64, 138, 12] 128\n",
" BasicConv2d-118 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-119 [-1, 64, 140, 14] 0\n",
" Conv2d-120 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-121 [-1, 128, 138, 12] 256\n",
" BasicConv2d-122 [-1, 128, 138, 12] 0\n",
" ZeroPad2d-123 [-1, 224, 138, 12] 0\n",
" Conv2d-124 [-1, 42, 138, 12] 9,408\n",
" BatchNorm2d-125 [-1, 42, 138, 12] 84\n",
" BasicConv2d-126 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-127 [-1, 42, 140, 14] 0\n",
" Conv2d-128 [-1, 64, 138, 12] 24,192\n",
" BatchNorm2d-129 [-1, 64, 138, 12] 128\n",
" BasicConv2d-130 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-131 [-1, 64, 140, 14] 0\n",
" Conv2d-132 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-133 [-1, 128, 138, 12] 256\n",
" BasicConv2d-134 [-1, 128, 138, 12] 0\n",
" BatchNorm2d-135 [-1, 298, 138, 12] 596\n",
"Multi_2D_CNN_block-136 [-1, 298, 138, 12] 0\n",
" ZeroPad2d-137 [-1, 298, 138, 12] 0\n",
" Conv2d-138 [-1, 42, 138, 12] 12,516\n",
" BatchNorm2d-139 [-1, 42, 138, 12] 84\n",
" BasicConv2d-140 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-141 [-1, 298, 138, 12] 0\n",
" Conv2d-142 [-1, 64, 138, 12] 19,072\n",
" BatchNorm2d-143 [-1, 64, 138, 12] 128\n",
" BasicConv2d-144 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-145 [-1, 64, 140, 14] 0\n",
" Conv2d-146 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-147 [-1, 128, 138, 12] 256\n",
" BasicConv2d-148 [-1, 128, 138, 12] 0\n",
" ZeroPad2d-149 [-1, 298, 138, 12] 0\n",
" Conv2d-150 [-1, 42, 138, 12] 12,516\n",
" BatchNorm2d-151 [-1, 42, 138, 12] 84\n",
" BasicConv2d-152 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-153 [-1, 42, 140, 14] 0\n",
" Conv2d-154 [-1, 64, 138, 12] 24,192\n",
" BatchNorm2d-155 [-1, 64, 138, 12] 128\n",
" BasicConv2d-156 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-157 [-1, 64, 140, 14] 0\n",
" Conv2d-158 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-159 [-1, 128, 138, 12] 256\n",
" BasicConv2d-160 [-1, 128, 138, 12] 0\n",
" BatchNorm2d-161 [-1, 298, 138, 12] 596\n",
"Multi_2D_CNN_block-162 [-1, 298, 138, 12] 0\n",
" MaxPool2d-163 [-1, 298, 69, 12] 0\n",
" ZeroPad2d-164 [-1, 298, 69, 12] 0\n",
" Conv2d-165 [-1, 64, 69, 12] 19,072\n",
" BatchNorm2d-166 [-1, 64, 69, 12] 128\n",
" BasicConv2d-167 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-168 [-1, 298, 69, 12] 0\n",
" Conv2d-169 [-1, 96, 69, 12] 28,608\n",
" BatchNorm2d-170 [-1, 96, 69, 12] 192\n",
" BasicConv2d-171 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-172 [-1, 96, 71, 14] 0\n",
" Conv2d-173 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-174 [-1, 192, 69, 12] 384\n",
" BasicConv2d-175 [-1, 192, 69, 12] 0\n",
" ZeroPad2d-176 [-1, 298, 69, 12] 0\n",
" Conv2d-177 [-1, 64, 69, 12] 19,072\n",
" BatchNorm2d-178 [-1, 64, 69, 12] 128\n",
" BasicConv2d-179 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-180 [-1, 64, 71, 14] 0\n",
" Conv2d-181 [-1, 96, 69, 12] 55,296\n",
" BatchNorm2d-182 [-1, 96, 69, 12] 192\n",
" BasicConv2d-183 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-184 [-1, 96, 71, 14] 0\n",
" Conv2d-185 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-186 [-1, 192, 69, 12] 384\n",
" BasicConv2d-187 [-1, 192, 69, 12] 0\n",
" BatchNorm2d-188 [-1, 448, 69, 12] 896\n",
"Multi_2D_CNN_block-189 [-1, 448, 69, 12] 0\n",
" ZeroPad2d-190 [-1, 448, 69, 12] 0\n",
" Conv2d-191 [-1, 64, 69, 12] 28,672\n",
" BatchNorm2d-192 [-1, 64, 69, 12] 128\n",
" BasicConv2d-193 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-194 [-1, 448, 69, 12] 0\n",
" Conv2d-195 [-1, 96, 69, 12] 43,008\n",
" BatchNorm2d-196 [-1, 96, 69, 12] 192\n",
" BasicConv2d-197 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-198 [-1, 96, 71, 14] 0\n",
" Conv2d-199 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-200 [-1, 192, 69, 12] 384\n",
" BasicConv2d-201 [-1, 192, 69, 12] 0\n",
" ZeroPad2d-202 [-1, 448, 69, 12] 0\n",
" Conv2d-203 [-1, 64, 69, 12] 28,672\n",
" BatchNorm2d-204 [-1, 64, 69, 12] 128\n",
" BasicConv2d-205 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-206 [-1, 64, 71, 14] 0\n",
" Conv2d-207 [-1, 96, 69, 12] 55,296\n",
" BatchNorm2d-208 [-1, 96, 69, 12] 192\n",
" BasicConv2d-209 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-210 [-1, 96, 71, 14] 0\n",
" Conv2d-211 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-212 [-1, 192, 69, 12] 384\n",
" BasicConv2d-213 [-1, 192, 69, 12] 0\n",
" BatchNorm2d-214 [-1, 448, 69, 12] 896\n",
"Multi_2D_CNN_block-215 [-1, 448, 69, 12] 0\n",
" ZeroPad2d-216 [-1, 448, 69, 12] 0\n",
" Conv2d-217 [-1, 85, 69, 12] 38,080\n",
" BatchNorm2d-218 [-1, 85, 69, 12] 170\n",
" BasicConv2d-219 [-1, 85, 69, 12] 0\n",
" ZeroPad2d-220 [-1, 448, 69, 12] 0\n",
" Conv2d-221 [-1, 128, 69, 12] 57,344\n",
" BatchNorm2d-222 [-1, 128, 69, 12] 256\n",
" BasicConv2d-223 [-1, 128, 69, 12] 0\n",
" ZeroPad2d-224 [-1, 128, 71, 14] 0\n",
" Conv2d-225 [-1, 256, 69, 12] 294,912\n",
" BatchNorm2d-226 [-1, 256, 69, 12] 512\n",
" BasicConv2d-227 [-1, 256, 69, 12] 0\n",
" ZeroPad2d-228 [-1, 448, 69, 12] 0\n",
" Conv2d-229 [-1, 85, 69, 12] 38,080\n",
" BatchNorm2d-230 [-1, 85, 69, 12] 170\n",
" BasicConv2d-231 [-1, 85, 69, 12] 0\n",
" ZeroPad2d-232 [-1, 85, 71, 14] 0\n",
" Conv2d-233 [-1, 128, 69, 12] 97,920\n",
" BatchNorm2d-234 [-1, 128, 69, 12] 256\n",
" BasicConv2d-235 [-1, 128, 69, 12] 0\n",
" ZeroPad2d-236 [-1, 128, 71, 14] 0\n",
" Conv2d-237 [-1, 256, 69, 12] 294,912\n",
" BatchNorm2d-238 [-1, 256, 69, 12] 512\n",
" BasicConv2d-239 [-1, 256, 69, 12] 0\n",
" BatchNorm2d-240 [-1, 597, 69, 12] 1,194\n",
"Multi_2D_CNN_block-241 [-1, 597, 69, 12] 0\n",
" MaxPool2d-242 [-1, 597, 34, 12] 0\n",
" ZeroPad2d-243 [-1, 597, 34, 12] 0\n",
" Conv2d-244 [-1, 106, 34, 12] 63,282\n",
" BatchNorm2d-245 [-1, 106, 34, 12] 212\n",
" BasicConv2d-246 [-1, 106, 34, 12] 0\n",
" ZeroPad2d-247 [-1, 597, 34, 12] 0\n",
" Conv2d-248 [-1, 160, 34, 12] 95,520\n",
" BatchNorm2d-249 [-1, 160, 34, 12] 320\n",
" BasicConv2d-250 [-1, 160, 34, 12] 0\n",
" ZeroPad2d-251 [-1, 160, 36, 14] 0\n",
" Conv2d-252 [-1, 320, 34, 12] 460,800\n",
" BatchNorm2d-253 [-1, 320, 34, 12] 640\n",
" BasicConv2d-254 [-1, 320, 34, 12] 0\n",
" ZeroPad2d-255 [-1, 597, 34, 12] 0\n",
" Conv2d-256 [-1, 106, 34, 12] 63,282\n",
" BatchNorm2d-257 [-1, 106, 34, 12] 212\n",
" BasicConv2d-258 [-1, 106, 34, 12] 0\n",
" ZeroPad2d-259 [-1, 106, 36, 14] 0\n",
" Conv2d-260 [-1, 160, 34, 12] 152,640\n",
" BatchNorm2d-261 [-1, 160, 34, 12] 320\n",
" BasicConv2d-262 [-1, 160, 34, 12] 0\n",
" ZeroPad2d-263 [-1, 160, 36, 14] 0\n",
" Conv2d-264 [-1, 320, 34, 12] 460,800\n",
" BatchNorm2d-265 [-1, 320, 34, 12] 640\n",
" BasicConv2d-266 [-1, 320, 34, 12] 0\n",
" BatchNorm2d-267 [-1, 746, 34, 12] 1,492\n",
"Multi_2D_CNN_block-268 [-1, 746, 34, 12] 0\n",
" ZeroPad2d-269 [-1, 746, 34, 12] 0\n",
" Conv2d-270 [-1, 128, 34, 12] 95,488\n",
" BatchNorm2d-271 [-1, 128, 34, 12] 256\n",
" BasicConv2d-272 [-1, 128, 34, 12] 0\n",
" ZeroPad2d-273 [-1, 746, 34, 12] 0\n",
" Conv2d-274 [-1, 192, 34, 12] 143,232\n",
" BatchNorm2d-275 [-1, 192, 34, 12] 384\n",
" BasicConv2d-276 [-1, 192, 34, 12] 0\n",
" ZeroPad2d-277 [-1, 192, 36, 14] 0\n",
" Conv2d-278 [-1, 384, 34, 12] 663,552\n",
" BatchNorm2d-279 [-1, 384, 34, 12] 768\n",
" BasicConv2d-280 [-1, 384, 34, 12] 0\n",
" ZeroPad2d-281 [-1, 746, 34, 12] 0\n",
" Conv2d-282 [-1, 128, 34, 12] 95,488\n",
" BatchNorm2d-283 [-1, 128, 34, 12] 256\n",
" BasicConv2d-284 [-1, 128, 34, 12] 0\n",
" ZeroPad2d-285 [-1, 128, 36, 14] 0\n",
" Conv2d-286 [-1, 192, 34, 12] 221,184\n",
" BatchNorm2d-287 [-1, 192, 34, 12] 384\n",
" BasicConv2d-288 [-1, 192, 34, 12] 0\n",
" ZeroPad2d-289 [-1, 192, 36, 14] 0\n",
" Conv2d-290 [-1, 384, 34, 12] 663,552\n",
" BatchNorm2d-291 [-1, 384, 34, 12] 768\n",
" BasicConv2d-292 [-1, 384, 34, 12] 0\n",
" BatchNorm2d-293 [-1, 896, 34, 12] 1,792\n",
"Multi_2D_CNN_block-294 [-1, 896, 34, 12] 0\n",
" ZeroPad2d-295 [-1, 896, 34, 12] 0\n",
" Conv2d-296 [-1, 149, 34, 12] 133,504\n",
" BatchNorm2d-297 [-1, 149, 34, 12] 298\n",
" BasicConv2d-298 [-1, 149, 34, 12] 0\n",
" ZeroPad2d-299 [-1, 896, 34, 12] 0\n",
" Conv2d-300 [-1, 224, 34, 12] 200,704\n",
" BatchNorm2d-301 [-1, 224, 34, 12] 448\n",
" BasicConv2d-302 [-1, 224, 34, 12] 0\n",
" ZeroPad2d-303 [-1, 224, 36, 14] 0\n",
" Conv2d-304 [-1, 448, 34, 12] 903,168\n",
" BatchNorm2d-305 [-1, 448, 34, 12] 896\n",
" BasicConv2d-306 [-1, 448, 34, 12] 0\n",
" ZeroPad2d-307 [-1, 896, 34, 12] 0\n",
" Conv2d-308 [-1, 149, 34, 12] 133,504\n",
" BatchNorm2d-309 [-1, 149, 34, 12] 298\n",
" BasicConv2d-310 [-1, 149, 34, 12] 0\n",
" ZeroPad2d-311 [-1, 149, 36, 14] 0\n",
" Conv2d-312 [-1, 224, 34, 12] 300,384\n",
" BatchNorm2d-313 [-1, 224, 34, 12] 448\n",
" BasicConv2d-314 [-1, 224, 34, 12] 0\n",
" ZeroPad2d-315 [-1, 224, 36, 14] 0\n",
" Conv2d-316 [-1, 448, 34, 12] 903,168\n",
" BatchNorm2d-317 [-1, 448, 34, 12] 896\n",
" BasicConv2d-318 [-1, 448, 34, 12] 0\n",
" BatchNorm2d-319 [-1, 1045, 34, 12] 2,090\n",
"Multi_2D_CNN_block-320 [-1, 1045, 34, 12] 0\n",
" MaxPool2d-321 [-1, 1045, 17, 12] 0\n",
" ZeroPad2d-322 [-1, 1045, 17, 12] 0\n",
" Conv2d-323 [-1, 170, 17, 12] 177,650\n",
" BatchNorm2d-324 [-1, 170, 17, 12] 340\n",
" BasicConv2d-325 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-326 [-1, 1045, 17, 12] 0\n",
" Conv2d-327 [-1, 256, 17, 12] 267,520\n",
" BatchNorm2d-328 [-1, 256, 17, 12] 512\n",
" BasicConv2d-329 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-330 [-1, 256, 19, 14] 0\n",
" Conv2d-331 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-332 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-333 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-334 [-1, 1045, 17, 12] 0\n",
" Conv2d-335 [-1, 170, 17, 12] 177,650\n",
" BatchNorm2d-336 [-1, 170, 17, 12] 340\n",
" BasicConv2d-337 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-338 [-1, 170, 19, 14] 0\n",
" Conv2d-339 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-340 [-1, 256, 17, 12] 512\n",
" BasicConv2d-341 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-342 [-1, 256, 19, 14] 0\n",
" Conv2d-343 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-344 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-345 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-346 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-347 [-1, 1194, 17, 12] 0\n",
" ZeroPad2d-348 [-1, 1194, 17, 12] 0\n",
" Conv2d-349 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-350 [-1, 170, 17, 12] 340\n",
" BasicConv2d-351 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-352 [-1, 1194, 17, 12] 0\n",
" Conv2d-353 [-1, 256, 17, 12] 305,664\n",
" BatchNorm2d-354 [-1, 256, 17, 12] 512\n",
" BasicConv2d-355 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-356 [-1, 256, 19, 14] 0\n",
" Conv2d-357 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-358 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-359 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-360 [-1, 1194, 17, 12] 0\n",
" Conv2d-361 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-362 [-1, 170, 17, 12] 340\n",
" BasicConv2d-363 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-364 [-1, 170, 19, 14] 0\n",
" Conv2d-365 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-366 [-1, 256, 17, 12] 512\n",
" BasicConv2d-367 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-368 [-1, 256, 19, 14] 0\n",
" Conv2d-369 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-370 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-371 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-372 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-373 [-1, 1194, 17, 12] 0\n",
" ZeroPad2d-374 [-1, 1194, 17, 12] 0\n",
" Conv2d-375 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-376 [-1, 170, 17, 12] 340\n",
" BasicConv2d-377 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-378 [-1, 1194, 17, 12] 0\n",
" Conv2d-379 [-1, 256, 17, 12] 305,664\n",
" BatchNorm2d-380 [-1, 256, 17, 12] 512\n",
" BasicConv2d-381 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-382 [-1, 256, 19, 14] 0\n",
" Conv2d-383 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-384 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-385 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-386 [-1, 1194, 17, 12] 0\n",
" Conv2d-387 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-388 [-1, 170, 17, 12] 340\n",
" BasicConv2d-389 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-390 [-1, 170, 19, 14] 0\n",
" Conv2d-391 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-392 [-1, 256, 17, 12] 512\n",
" BasicConv2d-393 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-394 [-1, 256, 19, 14] 0\n",
" Conv2d-395 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-396 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-397 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-398 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-399 [-1, 1194, 17, 12] 0\n",
" MaxPool2d-400 [-1, 1194, 8, 12] 0\n",
" ZeroPad2d-401 [-1, 1194, 8, 12] 0\n",
" Conv2d-402 [-1, 256, 8, 12] 305,664\n",
" BatchNorm2d-403 [-1, 256, 8, 12] 512\n",
" BasicConv2d-404 [-1, 256, 8, 12] 0\n",
" ZeroPad2d-405 [-1, 1194, 8, 12] 0\n",
" Conv2d-406 [-1, 384, 8, 12] 458,496\n",
" BatchNorm2d-407 [-1, 384, 8, 12] 768\n",
" BasicConv2d-408 [-1, 384, 8, 12] 0\n",
" ZeroPad2d-409 [-1, 384, 10, 14] 0\n",
" Conv2d-410 [-1, 768, 8, 12] 2,654,208\n",
" BatchNorm2d-411 [-1, 768, 8, 12] 1,536\n",
" BasicConv2d-412 [-1, 768, 8, 12] 0\n",
" ZeroPad2d-413 [-1, 1194, 8, 12] 0\n",
" Conv2d-414 [-1, 256, 8, 12] 305,664\n",
" BatchNorm2d-415 [-1, 256, 8, 12] 512\n",
" BasicConv2d-416 [-1, 256, 8, 12] 0\n",
" ZeroPad2d-417 [-1, 256, 10, 14] 0\n",
" Conv2d-418 [-1, 384, 8, 12] 884,736\n",
" BatchNorm2d-419 [-1, 384, 8, 12] 768\n",
" BasicConv2d-420 [-1, 384, 8, 12] 0\n",
" ZeroPad2d-421 [-1, 384, 10, 14] 0\n",
" Conv2d-422 [-1, 768, 8, 12] 2,654,208\n",
" BatchNorm2d-423 [-1, 768, 8, 12] 1,536\n",
" BasicConv2d-424 [-1, 768, 8, 12] 0\n",
" BatchNorm2d-425 [-1, 1792, 8, 12] 3,584\n",
"Multi_2D_CNN_block-426 [-1, 1792, 8, 12] 0\n",
" ZeroPad2d-427 [-1, 1792, 8, 12] 0\n",
" Conv2d-428 [-1, 298, 8, 12] 534,016\n",
" BatchNorm2d-429 [-1, 298, 8, 12] 596\n",
" BasicConv2d-430 [-1, 298, 8, 12] 0\n",
" ZeroPad2d-431 [-1, 1792, 8, 12] 0\n",
" Conv2d-432 [-1, 448, 8, 12] 802,816\n",
" BatchNorm2d-433 [-1, 448, 8, 12] 896\n",
" BasicConv2d-434 [-1, 448, 8, 12] 0\n",
" ZeroPad2d-435 [-1, 448, 10, 14] 0\n",
" Conv2d-436 [-1, 896, 8, 12] 3,612,672\n",
" BatchNorm2d-437 [-1, 896, 8, 12] 1,792\n",
" BasicConv2d-438 [-1, 896, 8, 12] 0\n",
" ZeroPad2d-439 [-1, 1792, 8, 12] 0\n",
" Conv2d-440 [-1, 298, 8, 12] 534,016\n",
" BatchNorm2d-441 [-1, 298, 8, 12] 596\n",
" BasicConv2d-442 [-1, 298, 8, 12] 0\n",
" ZeroPad2d-443 [-1, 298, 10, 14] 0\n",
" Conv2d-444 [-1, 448, 8, 12] 1,201,536\n",
" BatchNorm2d-445 [-1, 448, 8, 12] 896\n",
" BasicConv2d-446 [-1, 448, 8, 12] 0\n",
" ZeroPad2d-447 [-1, 448, 10, 14] 0\n",
" Conv2d-448 [-1, 896, 8, 12] 3,612,672\n",
" BatchNorm2d-449 [-1, 896, 8, 12] 1,792\n",
" BasicConv2d-450 [-1, 896, 8, 12] 0\n",
" BatchNorm2d-451 [-1, 2090, 8, 12] 4,180\n",
"Multi_2D_CNN_block-452 [-1, 2090, 8, 12] 0\n",
" ZeroPad2d-453 [-1, 2090, 8, 12] 0\n",
" Conv2d-454 [-1, 341, 8, 12] 712,690\n",
" BatchNorm2d-455 [-1, 341, 8, 12] 682\n",
" BasicConv2d-456 [-1, 341, 8, 12] 0\n",
" ZeroPad2d-457 [-1, 2090, 8, 12] 0\n",
" Conv2d-458 [-1, 512, 8, 12] 1,070,080\n",
" BatchNorm2d-459 [-1, 512, 8, 12] 1,024\n",
" BasicConv2d-460 [-1, 512, 8, 12] 0\n",
" ZeroPad2d-461 [-1, 512, 10, 14] 0\n",
" Conv2d-462 [-1, 1024, 8, 12] 4,718,592\n",
" BatchNorm2d-463 [-1, 1024, 8, 12] 2,048\n",
" BasicConv2d-464 [-1, 1024, 8, 12] 0\n",
" ZeroPad2d-465 [-1, 2090, 8, 12] 0\n",
" Conv2d-466 [-1, 341, 8, 12] 712,690\n",
" BatchNorm2d-467 [-1, 341, 8, 12] 682\n",
" BasicConv2d-468 [-1, 341, 8, 12] 0\n",
" ZeroPad2d-469 [-1, 341, 10, 14] 0\n",
" Conv2d-470 [-1, 512, 8, 12] 1,571,328\n",
" BatchNorm2d-471 [-1, 512, 8, 12] 1,024\n",
" BasicConv2d-472 [-1, 512, 8, 12] 0\n",
" ZeroPad2d-473 [-1, 512, 10, 14] 0\n",
" Conv2d-474 [-1, 1024, 8, 12] 4,718,592\n",
" BatchNorm2d-475 [-1, 1024, 8, 12] 2,048\n",
" BasicConv2d-476 [-1, 1024, 8, 12] 0\n",
" BatchNorm2d-477 [-1, 2389, 8, 12] 4,778\n",
"Multi_2D_CNN_block-478 [-1, 2389, 8, 12] 0\n",
"AdaptiveAvgPool2d-479 [-1, 2389, 1, 1] 0\n",
" Flatten-480 [-1, 2389] 0\n",
" Dropout-481 [-1, 2389] 0\n",
" Linear-482 [-1, 1] 2,390\n",
"================================================================\n",
"Total params: 49,719,798\n",
"Trainable params: 49,719,798\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.11\n",
"Forward/backward pass size (MB): 884.62\n",
"Params size (MB): 189.67\n",
"Estimated Total Size (MB): 1074.40\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"summary(MyModel(), input_size=(1, 2500, 12))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" ZeroPad2d-1 [-1, 1, 2514, 12] 0\n",
" Conv2d-2 [-1, 64, 1250, 12] 960\n",
" BatchNorm2d-3 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-4 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-5 [-1, 64, 1250, 12] 0\n",
" Conv2d-6 [-1, 21, 1250, 12] 1,344\n",
" BatchNorm2d-7 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-8 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-9 [-1, 64, 1250, 12] 0\n",
" Conv2d-10 [-1, 32, 1250, 12] 2,048\n",
" BatchNorm2d-11 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-12 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-13 [-1, 32, 1252, 14] 0\n",
" Conv2d-14 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-15 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-16 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-17 [-1, 64, 1250, 12] 0\n",
" Conv2d-18 [-1, 21, 1250, 12] 1,344\n",
" BatchNorm2d-19 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-20 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-21 [-1, 21, 1252, 14] 0\n",
" Conv2d-22 [-1, 32, 1250, 12] 6,048\n",
" BatchNorm2d-23 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-24 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-25 [-1, 32, 1252, 14] 0\n",
" Conv2d-26 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-27 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-28 [-1, 64, 1250, 12] 0\n",
" BatchNorm2d-29 [-1, 149, 1250, 12] 298\n",
"Multi_2D_CNN_block-30 [-1, 149, 1250, 12] 0\n",
" ZeroPad2d-31 [-1, 149, 1250, 12] 0\n",
" Conv2d-32 [-1, 21, 1250, 12] 3,129\n",
" BatchNorm2d-33 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-34 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-35 [-1, 149, 1250, 12] 0\n",
" Conv2d-36 [-1, 32, 1250, 12] 4,768\n",
" BatchNorm2d-37 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-38 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-39 [-1, 32, 1252, 14] 0\n",
" Conv2d-40 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-41 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-42 [-1, 64, 1250, 12] 0\n",
" ZeroPad2d-43 [-1, 149, 1250, 12] 0\n",
" Conv2d-44 [-1, 21, 1250, 12] 3,129\n",
" BatchNorm2d-45 [-1, 21, 1250, 12] 42\n",
" BasicConv2d-46 [-1, 21, 1250, 12] 0\n",
" ZeroPad2d-47 [-1, 21, 1252, 14] 0\n",
" Conv2d-48 [-1, 32, 1250, 12] 6,048\n",
" BatchNorm2d-49 [-1, 32, 1250, 12] 64\n",
" BasicConv2d-50 [-1, 32, 1250, 12] 0\n",
" ZeroPad2d-51 [-1, 32, 1252, 14] 0\n",
" Conv2d-52 [-1, 64, 1250, 12] 18,432\n",
" BatchNorm2d-53 [-1, 64, 1250, 12] 128\n",
" BasicConv2d-54 [-1, 64, 1250, 12] 0\n",
" BatchNorm2d-55 [-1, 149, 1250, 12] 298\n",
"Multi_2D_CNN_block-56 [-1, 149, 1250, 12] 0\n",
" MaxPool2d-57 [-1, 149, 416, 12] 0\n",
" ZeroPad2d-58 [-1, 149, 416, 12] 0\n",
" Conv2d-59 [-1, 32, 416, 12] 4,768\n",
" BatchNorm2d-60 [-1, 32, 416, 12] 64\n",
" BasicConv2d-61 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-62 [-1, 149, 416, 12] 0\n",
" Conv2d-63 [-1, 48, 416, 12] 7,152\n",
" BatchNorm2d-64 [-1, 48, 416, 12] 96\n",
" BasicConv2d-65 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-66 [-1, 48, 418, 14] 0\n",
" Conv2d-67 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-68 [-1, 96, 416, 12] 192\n",
" BasicConv2d-69 [-1, 96, 416, 12] 0\n",
" ZeroPad2d-70 [-1, 149, 416, 12] 0\n",
" Conv2d-71 [-1, 32, 416, 12] 4,768\n",
" BatchNorm2d-72 [-1, 32, 416, 12] 64\n",
" BasicConv2d-73 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-74 [-1, 32, 418, 14] 0\n",
" Conv2d-75 [-1, 48, 416, 12] 13,824\n",
" BatchNorm2d-76 [-1, 48, 416, 12] 96\n",
" BasicConv2d-77 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-78 [-1, 48, 418, 14] 0\n",
" Conv2d-79 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-80 [-1, 96, 416, 12] 192\n",
" BasicConv2d-81 [-1, 96, 416, 12] 0\n",
" BatchNorm2d-82 [-1, 224, 416, 12] 448\n",
"Multi_2D_CNN_block-83 [-1, 224, 416, 12] 0\n",
" ZeroPad2d-84 [-1, 224, 416, 12] 0\n",
" Conv2d-85 [-1, 32, 416, 12] 7,168\n",
" BatchNorm2d-86 [-1, 32, 416, 12] 64\n",
" BasicConv2d-87 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-88 [-1, 224, 416, 12] 0\n",
" Conv2d-89 [-1, 48, 416, 12] 10,752\n",
" BatchNorm2d-90 [-1, 48, 416, 12] 96\n",
" BasicConv2d-91 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-92 [-1, 48, 418, 14] 0\n",
" Conv2d-93 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-94 [-1, 96, 416, 12] 192\n",
" BasicConv2d-95 [-1, 96, 416, 12] 0\n",
" ZeroPad2d-96 [-1, 224, 416, 12] 0\n",
" Conv2d-97 [-1, 32, 416, 12] 7,168\n",
" BatchNorm2d-98 [-1, 32, 416, 12] 64\n",
" BasicConv2d-99 [-1, 32, 416, 12] 0\n",
" ZeroPad2d-100 [-1, 32, 418, 14] 0\n",
" Conv2d-101 [-1, 48, 416, 12] 13,824\n",
" BatchNorm2d-102 [-1, 48, 416, 12] 96\n",
" BasicConv2d-103 [-1, 48, 416, 12] 0\n",
" ZeroPad2d-104 [-1, 48, 418, 14] 0\n",
" Conv2d-105 [-1, 96, 416, 12] 41,472\n",
" BatchNorm2d-106 [-1, 96, 416, 12] 192\n",
" BasicConv2d-107 [-1, 96, 416, 12] 0\n",
" BatchNorm2d-108 [-1, 224, 416, 12] 448\n",
"Multi_2D_CNN_block-109 [-1, 224, 416, 12] 0\n",
" MaxPool2d-110 [-1, 224, 138, 12] 0\n",
" ZeroPad2d-111 [-1, 224, 138, 12] 0\n",
" Conv2d-112 [-1, 42, 138, 12] 9,408\n",
" BatchNorm2d-113 [-1, 42, 138, 12] 84\n",
" BasicConv2d-114 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-115 [-1, 224, 138, 12] 0\n",
" Conv2d-116 [-1, 64, 138, 12] 14,336\n",
" BatchNorm2d-117 [-1, 64, 138, 12] 128\n",
" BasicConv2d-118 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-119 [-1, 64, 140, 14] 0\n",
" Conv2d-120 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-121 [-1, 128, 138, 12] 256\n",
" BasicConv2d-122 [-1, 128, 138, 12] 0\n",
" ZeroPad2d-123 [-1, 224, 138, 12] 0\n",
" Conv2d-124 [-1, 42, 138, 12] 9,408\n",
" BatchNorm2d-125 [-1, 42, 138, 12] 84\n",
" BasicConv2d-126 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-127 [-1, 42, 140, 14] 0\n",
" Conv2d-128 [-1, 64, 138, 12] 24,192\n",
" BatchNorm2d-129 [-1, 64, 138, 12] 128\n",
" BasicConv2d-130 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-131 [-1, 64, 140, 14] 0\n",
" Conv2d-132 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-133 [-1, 128, 138, 12] 256\n",
" BasicConv2d-134 [-1, 128, 138, 12] 0\n",
" BatchNorm2d-135 [-1, 298, 138, 12] 596\n",
"Multi_2D_CNN_block-136 [-1, 298, 138, 12] 0\n",
" ZeroPad2d-137 [-1, 298, 138, 12] 0\n",
" Conv2d-138 [-1, 42, 138, 12] 12,516\n",
" BatchNorm2d-139 [-1, 42, 138, 12] 84\n",
" BasicConv2d-140 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-141 [-1, 298, 138, 12] 0\n",
" Conv2d-142 [-1, 64, 138, 12] 19,072\n",
" BatchNorm2d-143 [-1, 64, 138, 12] 128\n",
" BasicConv2d-144 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-145 [-1, 64, 140, 14] 0\n",
" Conv2d-146 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-147 [-1, 128, 138, 12] 256\n",
" BasicConv2d-148 [-1, 128, 138, 12] 0\n",
" ZeroPad2d-149 [-1, 298, 138, 12] 0\n",
" Conv2d-150 [-1, 42, 138, 12] 12,516\n",
" BatchNorm2d-151 [-1, 42, 138, 12] 84\n",
" BasicConv2d-152 [-1, 42, 138, 12] 0\n",
" ZeroPad2d-153 [-1, 42, 140, 14] 0\n",
" Conv2d-154 [-1, 64, 138, 12] 24,192\n",
" BatchNorm2d-155 [-1, 64, 138, 12] 128\n",
" BasicConv2d-156 [-1, 64, 138, 12] 0\n",
" ZeroPad2d-157 [-1, 64, 140, 14] 0\n",
" Conv2d-158 [-1, 128, 138, 12] 73,728\n",
" BatchNorm2d-159 [-1, 128, 138, 12] 256\n",
" BasicConv2d-160 [-1, 128, 138, 12] 0\n",
" BatchNorm2d-161 [-1, 298, 138, 12] 596\n",
"Multi_2D_CNN_block-162 [-1, 298, 138, 12] 0\n",
" MaxPool2d-163 [-1, 298, 69, 12] 0\n",
" ZeroPad2d-164 [-1, 298, 69, 12] 0\n",
" Conv2d-165 [-1, 64, 69, 12] 19,072\n",
" BatchNorm2d-166 [-1, 64, 69, 12] 128\n",
" BasicConv2d-167 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-168 [-1, 298, 69, 12] 0\n",
" Conv2d-169 [-1, 96, 69, 12] 28,608\n",
" BatchNorm2d-170 [-1, 96, 69, 12] 192\n",
" BasicConv2d-171 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-172 [-1, 96, 71, 14] 0\n",
" Conv2d-173 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-174 [-1, 192, 69, 12] 384\n",
" BasicConv2d-175 [-1, 192, 69, 12] 0\n",
" ZeroPad2d-176 [-1, 298, 69, 12] 0\n",
" Conv2d-177 [-1, 64, 69, 12] 19,072\n",
" BatchNorm2d-178 [-1, 64, 69, 12] 128\n",
" BasicConv2d-179 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-180 [-1, 64, 71, 14] 0\n",
" Conv2d-181 [-1, 96, 69, 12] 55,296\n",
" BatchNorm2d-182 [-1, 96, 69, 12] 192\n",
" BasicConv2d-183 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-184 [-1, 96, 71, 14] 0\n",
" Conv2d-185 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-186 [-1, 192, 69, 12] 384\n",
" BasicConv2d-187 [-1, 192, 69, 12] 0\n",
" BatchNorm2d-188 [-1, 448, 69, 12] 896\n",
"Multi_2D_CNN_block-189 [-1, 448, 69, 12] 0\n",
" ZeroPad2d-190 [-1, 448, 69, 12] 0\n",
" Conv2d-191 [-1, 64, 69, 12] 28,672\n",
" BatchNorm2d-192 [-1, 64, 69, 12] 128\n",
" BasicConv2d-193 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-194 [-1, 448, 69, 12] 0\n",
" Conv2d-195 [-1, 96, 69, 12] 43,008\n",
" BatchNorm2d-196 [-1, 96, 69, 12] 192\n",
" BasicConv2d-197 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-198 [-1, 96, 71, 14] 0\n",
" Conv2d-199 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-200 [-1, 192, 69, 12] 384\n",
" BasicConv2d-201 [-1, 192, 69, 12] 0\n",
" ZeroPad2d-202 [-1, 448, 69, 12] 0\n",
" Conv2d-203 [-1, 64, 69, 12] 28,672\n",
" BatchNorm2d-204 [-1, 64, 69, 12] 128\n",
" BasicConv2d-205 [-1, 64, 69, 12] 0\n",
" ZeroPad2d-206 [-1, 64, 71, 14] 0\n",
" Conv2d-207 [-1, 96, 69, 12] 55,296\n",
" BatchNorm2d-208 [-1, 96, 69, 12] 192\n",
" BasicConv2d-209 [-1, 96, 69, 12] 0\n",
" ZeroPad2d-210 [-1, 96, 71, 14] 0\n",
" Conv2d-211 [-1, 192, 69, 12] 165,888\n",
" BatchNorm2d-212 [-1, 192, 69, 12] 384\n",
" BasicConv2d-213 [-1, 192, 69, 12] 0\n",
" BatchNorm2d-214 [-1, 448, 69, 12] 896\n",
"Multi_2D_CNN_block-215 [-1, 448, 69, 12] 0\n",
" ZeroPad2d-216 [-1, 448, 69, 12] 0\n",
" Conv2d-217 [-1, 85, 69, 12] 38,080\n",
" BatchNorm2d-218 [-1, 85, 69, 12] 170\n",
" BasicConv2d-219 [-1, 85, 69, 12] 0\n",
" ZeroPad2d-220 [-1, 448, 69, 12] 0\n",
" Conv2d-221 [-1, 128, 69, 12] 57,344\n",
" BatchNorm2d-222 [-1, 128, 69, 12] 256\n",
" BasicConv2d-223 [-1, 128, 69, 12] 0\n",
" ZeroPad2d-224 [-1, 128, 71, 14] 0\n",
" Conv2d-225 [-1, 256, 69, 12] 294,912\n",
" BatchNorm2d-226 [-1, 256, 69, 12] 512\n",
" BasicConv2d-227 [-1, 256, 69, 12] 0\n",
" ZeroPad2d-228 [-1, 448, 69, 12] 0\n",
" Conv2d-229 [-1, 85, 69, 12] 38,080\n",
" BatchNorm2d-230 [-1, 85, 69, 12] 170\n",
" BasicConv2d-231 [-1, 85, 69, 12] 0\n",
" ZeroPad2d-232 [-1, 85, 71, 14] 0\n",
" Conv2d-233 [-1, 128, 69, 12] 97,920\n",
" BatchNorm2d-234 [-1, 128, 69, 12] 256\n",
" BasicConv2d-235 [-1, 128, 69, 12] 0\n",
" ZeroPad2d-236 [-1, 128, 71, 14] 0\n",
" Conv2d-237 [-1, 256, 69, 12] 294,912\n",
" BatchNorm2d-238 [-1, 256, 69, 12] 512\n",
" BasicConv2d-239 [-1, 256, 69, 12] 0\n",
" BatchNorm2d-240 [-1, 597, 69, 12] 1,194\n",
"Multi_2D_CNN_block-241 [-1, 597, 69, 12] 0\n",
" MaxPool2d-242 [-1, 597, 34, 12] 0\n",
" ZeroPad2d-243 [-1, 597, 34, 12] 0\n",
" Conv2d-244 [-1, 106, 34, 12] 63,282\n",
" BatchNorm2d-245 [-1, 106, 34, 12] 212\n",
" BasicConv2d-246 [-1, 106, 34, 12] 0\n",
" ZeroPad2d-247 [-1, 597, 34, 12] 0\n",
" Conv2d-248 [-1, 160, 34, 12] 95,520\n",
" BatchNorm2d-249 [-1, 160, 34, 12] 320\n",
" BasicConv2d-250 [-1, 160, 34, 12] 0\n",
" ZeroPad2d-251 [-1, 160, 36, 14] 0\n",
" Conv2d-252 [-1, 320, 34, 12] 460,800\n",
" BatchNorm2d-253 [-1, 320, 34, 12] 640\n",
" BasicConv2d-254 [-1, 320, 34, 12] 0\n",
" ZeroPad2d-255 [-1, 597, 34, 12] 0\n",
" Conv2d-256 [-1, 106, 34, 12] 63,282\n",
" BatchNorm2d-257 [-1, 106, 34, 12] 212\n",
" BasicConv2d-258 [-1, 106, 34, 12] 0\n",
" ZeroPad2d-259 [-1, 106, 36, 14] 0\n",
" Conv2d-260 [-1, 160, 34, 12] 152,640\n",
" BatchNorm2d-261 [-1, 160, 34, 12] 320\n",
" BasicConv2d-262 [-1, 160, 34, 12] 0\n",
" ZeroPad2d-263 [-1, 160, 36, 14] 0\n",
" Conv2d-264 [-1, 320, 34, 12] 460,800\n",
" BatchNorm2d-265 [-1, 320, 34, 12] 640\n",
" BasicConv2d-266 [-1, 320, 34, 12] 0\n",
" BatchNorm2d-267 [-1, 746, 34, 12] 1,492\n",
"Multi_2D_CNN_block-268 [-1, 746, 34, 12] 0\n",
" ZeroPad2d-269 [-1, 746, 34, 12] 0\n",
" Conv2d-270 [-1, 128, 34, 12] 95,488\n",
" BatchNorm2d-271 [-1, 128, 34, 12] 256\n",
" BasicConv2d-272 [-1, 128, 34, 12] 0\n",
" ZeroPad2d-273 [-1, 746, 34, 12] 0\n",
" Conv2d-274 [-1, 192, 34, 12] 143,232\n",
" BatchNorm2d-275 [-1, 192, 34, 12] 384\n",
" BasicConv2d-276 [-1, 192, 34, 12] 0\n",
" ZeroPad2d-277 [-1, 192, 36, 14] 0\n",
" Conv2d-278 [-1, 384, 34, 12] 663,552\n",
" BatchNorm2d-279 [-1, 384, 34, 12] 768\n",
" BasicConv2d-280 [-1, 384, 34, 12] 0\n",
" ZeroPad2d-281 [-1, 746, 34, 12] 0\n",
" Conv2d-282 [-1, 128, 34, 12] 95,488\n",
" BatchNorm2d-283 [-1, 128, 34, 12] 256\n",
" BasicConv2d-284 [-1, 128, 34, 12] 0\n",
" ZeroPad2d-285 [-1, 128, 36, 14] 0\n",
" Conv2d-286 [-1, 192, 34, 12] 221,184\n",
" BatchNorm2d-287 [-1, 192, 34, 12] 384\n",
" BasicConv2d-288 [-1, 192, 34, 12] 0\n",
" ZeroPad2d-289 [-1, 192, 36, 14] 0\n",
" Conv2d-290 [-1, 384, 34, 12] 663,552\n",
" BatchNorm2d-291 [-1, 384, 34, 12] 768\n",
" BasicConv2d-292 [-1, 384, 34, 12] 0\n",
" BatchNorm2d-293 [-1, 896, 34, 12] 1,792\n",
"Multi_2D_CNN_block-294 [-1, 896, 34, 12] 0\n",
" ZeroPad2d-295 [-1, 896, 34, 12] 0\n",
" Conv2d-296 [-1, 149, 34, 12] 133,504\n",
" BatchNorm2d-297 [-1, 149, 34, 12] 298\n",
" BasicConv2d-298 [-1, 149, 34, 12] 0\n",
" ZeroPad2d-299 [-1, 896, 34, 12] 0\n",
" Conv2d-300 [-1, 224, 34, 12] 200,704\n",
" BatchNorm2d-301 [-1, 224, 34, 12] 448\n",
" BasicConv2d-302 [-1, 224, 34, 12] 0\n",
" ZeroPad2d-303 [-1, 224, 36, 14] 0\n",
" Conv2d-304 [-1, 448, 34, 12] 903,168\n",
" BatchNorm2d-305 [-1, 448, 34, 12] 896\n",
" BasicConv2d-306 [-1, 448, 34, 12] 0\n",
" ZeroPad2d-307 [-1, 896, 34, 12] 0\n",
" Conv2d-308 [-1, 149, 34, 12] 133,504\n",
" BatchNorm2d-309 [-1, 149, 34, 12] 298\n",
" BasicConv2d-310 [-1, 149, 34, 12] 0\n",
" ZeroPad2d-311 [-1, 149, 36, 14] 0\n",
" Conv2d-312 [-1, 224, 34, 12] 300,384\n",
" BatchNorm2d-313 [-1, 224, 34, 12] 448\n",
" BasicConv2d-314 [-1, 224, 34, 12] 0\n",
" ZeroPad2d-315 [-1, 224, 36, 14] 0\n",
" Conv2d-316 [-1, 448, 34, 12] 903,168\n",
" BatchNorm2d-317 [-1, 448, 34, 12] 896\n",
" BasicConv2d-318 [-1, 448, 34, 12] 0\n",
" BatchNorm2d-319 [-1, 1045, 34, 12] 2,090\n",
"Multi_2D_CNN_block-320 [-1, 1045, 34, 12] 0\n",
" MaxPool2d-321 [-1, 1045, 17, 12] 0\n",
" ZeroPad2d-322 [-1, 1045, 17, 12] 0\n",
" Conv2d-323 [-1, 170, 17, 12] 177,650\n",
" BatchNorm2d-324 [-1, 170, 17, 12] 340\n",
" BasicConv2d-325 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-326 [-1, 1045, 17, 12] 0\n",
" Conv2d-327 [-1, 256, 17, 12] 267,520\n",
" BatchNorm2d-328 [-1, 256, 17, 12] 512\n",
" BasicConv2d-329 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-330 [-1, 256, 19, 14] 0\n",
" Conv2d-331 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-332 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-333 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-334 [-1, 1045, 17, 12] 0\n",
" Conv2d-335 [-1, 170, 17, 12] 177,650\n",
" BatchNorm2d-336 [-1, 170, 17, 12] 340\n",
" BasicConv2d-337 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-338 [-1, 170, 19, 14] 0\n",
" Conv2d-339 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-340 [-1, 256, 17, 12] 512\n",
" BasicConv2d-341 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-342 [-1, 256, 19, 14] 0\n",
" Conv2d-343 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-344 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-345 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-346 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-347 [-1, 1194, 17, 12] 0\n",
" ZeroPad2d-348 [-1, 1194, 17, 12] 0\n",
" Conv2d-349 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-350 [-1, 170, 17, 12] 340\n",
" BasicConv2d-351 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-352 [-1, 1194, 17, 12] 0\n",
" Conv2d-353 [-1, 256, 17, 12] 305,664\n",
" BatchNorm2d-354 [-1, 256, 17, 12] 512\n",
" BasicConv2d-355 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-356 [-1, 256, 19, 14] 0\n",
" Conv2d-357 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-358 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-359 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-360 [-1, 1194, 17, 12] 0\n",
" Conv2d-361 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-362 [-1, 170, 17, 12] 340\n",
" BasicConv2d-363 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-364 [-1, 170, 19, 14] 0\n",
" Conv2d-365 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-366 [-1, 256, 17, 12] 512\n",
" BasicConv2d-367 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-368 [-1, 256, 19, 14] 0\n",
" Conv2d-369 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-370 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-371 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-372 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-373 [-1, 1194, 17, 12] 0\n",
" ZeroPad2d-374 [-1, 1194, 17, 12] 0\n",
" Conv2d-375 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-376 [-1, 170, 17, 12] 340\n",
" BasicConv2d-377 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-378 [-1, 1194, 17, 12] 0\n",
" Conv2d-379 [-1, 256, 17, 12] 305,664\n",
" BatchNorm2d-380 [-1, 256, 17, 12] 512\n",
" BasicConv2d-381 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-382 [-1, 256, 19, 14] 0\n",
" Conv2d-383 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-384 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-385 [-1, 512, 17, 12] 0\n",
" ZeroPad2d-386 [-1, 1194, 17, 12] 0\n",
" Conv2d-387 [-1, 170, 17, 12] 202,980\n",
" BatchNorm2d-388 [-1, 170, 17, 12] 340\n",
" BasicConv2d-389 [-1, 170, 17, 12] 0\n",
" ZeroPad2d-390 [-1, 170, 19, 14] 0\n",
" Conv2d-391 [-1, 256, 17, 12] 391,680\n",
" BatchNorm2d-392 [-1, 256, 17, 12] 512\n",
" BasicConv2d-393 [-1, 256, 17, 12] 0\n",
" ZeroPad2d-394 [-1, 256, 19, 14] 0\n",
" Conv2d-395 [-1, 512, 17, 12] 1,179,648\n",
" BatchNorm2d-396 [-1, 512, 17, 12] 1,024\n",
" BasicConv2d-397 [-1, 512, 17, 12] 0\n",
" BatchNorm2d-398 [-1, 1194, 17, 12] 2,388\n",
"Multi_2D_CNN_block-399 [-1, 1194, 17, 12] 0\n",
" MaxPool2d-400 [-1, 1194, 8, 12] 0\n",
" ZeroPad2d-401 [-1, 1194, 8, 12] 0\n",
" Conv2d-402 [-1, 256, 8, 12] 305,664\n",
" BatchNorm2d-403 [-1, 256, 8, 12] 512\n",
" BasicConv2d-404 [-1, 256, 8, 12] 0\n",
" ZeroPad2d-405 [-1, 1194, 8, 12] 0\n",
" Conv2d-406 [-1, 384, 8, 12] 458,496\n",
" BatchNorm2d-407 [-1, 384, 8, 12] 768\n",
" BasicConv2d-408 [-1, 384, 8, 12] 0\n",
" ZeroPad2d-409 [-1, 384, 10, 14] 0\n",
" Conv2d-410 [-1, 768, 8, 12] 2,654,208\n",
" BatchNorm2d-411 [-1, 768, 8, 12] 1,536\n",
" BasicConv2d-412 [-1, 768, 8, 12] 0\n",
" ZeroPad2d-413 [-1, 1194, 8, 12] 0\n",
" Conv2d-414 [-1, 256, 8, 12] 305,664\n",
" BatchNorm2d-415 [-1, 256, 8, 12] 512\n",
" BasicConv2d-416 [-1, 256, 8, 12] 0\n",
" ZeroPad2d-417 [-1, 256, 10, 14] 0\n",
" Conv2d-418 [-1, 384, 8, 12] 884,736\n",
" BatchNorm2d-419 [-1, 384, 8, 12] 768\n",
" BasicConv2d-420 [-1, 384, 8, 12] 0\n",
" ZeroPad2d-421 [-1, 384, 10, 14] 0\n",
" Conv2d-422 [-1, 768, 8, 12] 2,654,208\n",
" BatchNorm2d-423 [-1, 768, 8, 12] 1,536\n",
" BasicConv2d-424 [-1, 768, 8, 12] 0\n",
" BatchNorm2d-425 [-1, 1792, 8, 12] 3,584\n",
"Multi_2D_CNN_block-426 [-1, 1792, 8, 12] 0\n",
" ZeroPad2d-427 [-1, 1792, 8, 12] 0\n",
" Conv2d-428 [-1, 298, 8, 12] 534,016\n",
" BatchNorm2d-429 [-1, 298, 8, 12] 596\n",
" BasicConv2d-430 [-1, 298, 8, 12] 0\n",
" ZeroPad2d-431 [-1, 1792, 8, 12] 0\n",
" Conv2d-432 [-1, 448, 8, 12] 802,816\n",
" BatchNorm2d-433 [-1, 448, 8, 12] 896\n",
" BasicConv2d-434 [-1, 448, 8, 12] 0\n",
" ZeroPad2d-435 [-1, 448, 10, 14] 0\n",
" Conv2d-436 [-1, 896, 8, 12] 3,612,672\n",
" BatchNorm2d-437 [-1, 896, 8, 12] 1,792\n",
" BasicConv2d-438 [-1, 896, 8, 12] 0\n",
" ZeroPad2d-439 [-1, 1792, 8, 12] 0\n",
" Conv2d-440 [-1, 298, 8, 12] 534,016\n",
" BatchNorm2d-441 [-1, 298, 8, 12] 596\n",
" BasicConv2d-442 [-1, 298, 8, 12] 0\n",
" ZeroPad2d-443 [-1, 298, 10, 14] 0\n",
" Conv2d-444 [-1, 448, 8, 12] 1,201,536\n",
" BatchNorm2d-445 [-1, 448, 8, 12] 896\n",
" BasicConv2d-446 [-1, 448, 8, 12] 0\n",
" ZeroPad2d-447 [-1, 448, 10, 14] 0\n",
" Conv2d-448 [-1, 896, 8, 12] 3,612,672\n",
" BatchNorm2d-449 [-1, 896, 8, 12] 1,792\n",
" BasicConv2d-450 [-1, 896, 8, 12] 0\n",
" BatchNorm2d-451 [-1, 2090, 8, 12] 4,180\n",
"Multi_2D_CNN_block-452 [-1, 2090, 8, 12] 0\n",
" ZeroPad2d-453 [-1, 2090, 8, 12] 0\n",
" Conv2d-454 [-1, 341, 8, 12] 712,690\n",
" BatchNorm2d-455 [-1, 341, 8, 12] 682\n",
" BasicConv2d-456 [-1, 341, 8, 12] 0\n",
" ZeroPad2d-457 [-1, 2090, 8, 12] 0\n",
" Conv2d-458 [-1, 512, 8, 12] 1,070,080\n",
" BatchNorm2d-459 [-1, 512, 8, 12] 1,024\n",
" BasicConv2d-460 [-1, 512, 8, 12] 0\n",
" ZeroPad2d-461 [-1, 512, 10, 14] 0\n",
" Conv2d-462 [-1, 1024, 8, 12] 4,718,592\n",
" BatchNorm2d-463 [-1, 1024, 8, 12] 2,048\n",
" BasicConv2d-464 [-1, 1024, 8, 12] 0\n",
" ZeroPad2d-465 [-1, 2090, 8, 12] 0\n",
" Conv2d-466 [-1, 341, 8, 12] 712,690\n",
" BatchNorm2d-467 [-1, 341, 8, 12] 682\n",
" BasicConv2d-468 [-1, 341, 8, 12] 0\n",
" ZeroPad2d-469 [-1, 341, 10, 14] 0\n",
" Conv2d-470 [-1, 512, 8, 12] 1,571,328\n",
" BatchNorm2d-471 [-1, 512, 8, 12] 1,024\n",
" BasicConv2d-472 [-1, 512, 8, 12] 0\n",
" ZeroPad2d-473 [-1, 512, 10, 14] 0\n",
" Conv2d-474 [-1, 1024, 8, 12] 4,718,592\n",
" BatchNorm2d-475 [-1, 1024, 8, 12] 2,048\n",
" BasicConv2d-476 [-1, 1024, 8, 12] 0\n",
" BatchNorm2d-477 [-1, 2389, 8, 12] 4,778\n",
"Multi_2D_CNN_block-478 [-1, 2389, 8, 12] 0\n",
"AdaptiveAvgPool2d-479 [-1, 2389, 1, 1] 0\n",
" Flatten-480 [-1, 2389] 0\n",
" Dropout-481 [-1, 2389] 0\n",
" Linear-482 [-1, 1] 2,390\n",
"================================================================\n",
"Total params: 49,719,414\n",
"Trainable params: 49,719,414\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.11\n",
"Forward/backward pass size (MB): 884.58\n",
"Params size (MB): 189.66\n",
"Estimated Total Size (MB): 1074.36\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"summary(Inception_Tabular_15_1(), input_size=(1, 2500, 12))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Net(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(1, 64, kernel_size=(7, 3), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv3): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv4): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv5): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv6): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv7): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (fc1): Sequential(\n",
" (0): Flatten()\n",
" (1): Linear(in_features=55680, out_features=64, bias=True)\n",
" (2): ReLU()\n",
" (3): Linear(in_features=64, out_features=32, bias=True)\n",
" (4): ReLU()\n",
" (5): Linear(in_features=32, out_features=1, bias=True)\n",
" )\n",
")\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import TensorDataset\n",
"import torch.optim as optim\n",
"from torch.optim import lr_scheduler\n",
"import numpy as np\n",
"import torchvision\n",
"import torch.nn.functional as F\n",
"from torch.utils.data.sampler import SubsetRandomSampler\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets, models, transforms\n",
"from torchvision.transforms import Resize, ToTensor, Normalize\n",
"import matplotlib.pyplot as plt\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"\n",
"from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_auc_score, \\\n",
" average_precision_score\n",
"from sklearn.model_selection import train_test_split\n",
"import time\n",
"import os\n",
"from pathlib import Path\n",
"from skimage import io\n",
"import copy\n",
"from torch import optim, cuda\n",
"import pandas as pd\n",
"import glob\n",
"from collections import Counter\n",
"# Useful for examining network\n",
"from functools import reduce\n",
"from operator import __add__\n",
"# from torchsummary import summary\n",
"import seaborn as sns\n",
"import warnings\n",
"# warnings.filterwarnings('ignore', category=FutureWarning)\n",
"from PIL import Image\n",
"from timeit import default_timer as timer\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Useful for examining network\n",
"from functools import reduce\n",
"from operator import __add__\n",
"from torchsummary import summary\n",
"\n",
"# from IPython.core.interactiveshell import InteractiveShell\n",
"import seaborn as sns\n",
"\n",
"import warnings\n",
"# warnings.filterwarnings('ignore', category=FutureWarning)\n",
"\n",
"# Image manipulations\n",
"from PIL import Image\n",
"\n",
"# Timing utility\n",
"from timeit import default_timer as timer\n",
"\n",
"# Visualizations\n",
"import matplotlib.pyplot as plt\n",
"class Net(nn.Module):\n",
"\n",
" # Convolution as a whole should span at least 1 beat, preferably more\n",
" # Input shape = (1,2500,12)\n",
"# summary(model, input_size=(1, 2500, 12))\n",
" # CLEAR EXPERIMENTS TO TRY\n",
"# Alter kernel size, right now drops to 10 channels at beginning then stays there\n",
"# Try increasing output channel size as you go deeper \n",
"# Alter stride to have larger image at FC layer\n",
"\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
"\n",
" base_conv = 64\n",
" self.conv1 = nn.Sequential( \n",
" nn.Conv2d(in_channels=1, out_channels=base_conv, kernel_size=(7,3), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(3,1))\n",
" )\n",
"\n",
" self.conv2 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n",
" )\n",
"\n",
" self.conv3 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(3,1))\n",
" )\n",
"\n",
" self.conv4 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n",
" )\n",
"\n",
" self.conv5 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(3,1))\n",
" )\n",
" self.conv6 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n",
" )\n",
"\n",
" self.conv7 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n",
" )\n",
"\n",
" self.fc1 = nn.Sequential(\n",
"# nn.AdaptiveAvgPool2d((1, 1)),\n",
" nn.Flatten(),\n",
" nn.Linear(64*87*10, 64), #64 kernel size, 2500 pooled to 29, \n",
" nn.ReLU(),\n",
" nn.Linear(64, 32),\n",
" nn.ReLU(),\n",
" nn.Linear(32, 1))\n",
"\n",
" def forward(self, x):\n",
" out = self.conv1(x)\n",
"# print(out.shape)\n",
" out = self.conv2(out)\n",
"# print(out.shape) \n",
" out = self.conv3(out)\n",
"# print(out.shape)\n",
" out = self.conv4(out)\n",
"# print(out.shape)\n",
" out = self.conv5(out)\n",
"# print(out.shape)\n",
" out = self.conv6(out)\n",
"# print(out.shape)\n",
" out = self.conv7(out)\n",
"# print(out.shape)\n",
" out = self.fc1(out)\n",
" return out\n",
"\n",
"model = Net()\n",
"print(model)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 64, 831, 10])\n",
"torch.Size([2, 64, 829, 10])\n",
"torch.Size([2, 64, 276, 10])\n",
"torch.Size([2, 64, 274, 10])\n",
"torch.Size([2, 64, 91, 10])\n",
"torch.Size([2, 64, 89, 10])\n",
"torch.Size([2, 64, 87, 10])\n",
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 2494, 10] 1,408\n",
" ReLU-2 [-1, 64, 2494, 10] 0\n",
" BatchNorm2d-3 [-1, 64, 2494, 10] 128\n",
" Conv2d-4 [-1, 64, 2494, 10] 4,160\n",
" ReLU-5 [-1, 64, 2494, 10] 0\n",
" BatchNorm2d-6 [-1, 64, 2494, 10] 128\n",
" MaxPool2d-7 [-1, 64, 831, 10] 0\n",
" Conv2d-8 [-1, 64, 831, 10] 4,160\n",
" ReLU-9 [-1, 64, 831, 10] 0\n",
" BatchNorm2d-10 [-1, 64, 831, 10] 128\n",
" Conv2d-11 [-1, 64, 831, 10] 4,160\n",
" ReLU-12 [-1, 64, 831, 10] 0\n",
" BatchNorm2d-13 [-1, 64, 831, 10] 128\n",
" MaxPool2d-14 [-1, 64, 829, 10] 0\n",
" Conv2d-15 [-1, 64, 829, 10] 4,160\n",
" ReLU-16 [-1, 64, 829, 10] 0\n",
" BatchNorm2d-17 [-1, 64, 829, 10] 128\n",
" Conv2d-18 [-1, 64, 829, 10] 4,160\n",
" ReLU-19 [-1, 64, 829, 10] 0\n",
" BatchNorm2d-20 [-1, 64, 829, 10] 128\n",
" MaxPool2d-21 [-1, 64, 276, 10] 0\n",
" Conv2d-22 [-1, 64, 276, 10] 4,160\n",
" ReLU-23 [-1, 64, 276, 10] 0\n",
" BatchNorm2d-24 [-1, 64, 276, 10] 128\n",
" Conv2d-25 [-1, 64, 276, 10] 4,160\n",
" ReLU-26 [-1, 64, 276, 10] 0\n",
" BatchNorm2d-27 [-1, 64, 276, 10] 128\n",
" MaxPool2d-28 [-1, 64, 274, 10] 0\n",
" Conv2d-29 [-1, 64, 274, 10] 4,160\n",
" ReLU-30 [-1, 64, 274, 10] 0\n",
" BatchNorm2d-31 [-1, 64, 274, 10] 128\n",
" Conv2d-32 [-1, 64, 274, 10] 4,160\n",
" ReLU-33 [-1, 64, 274, 10] 0\n",
" BatchNorm2d-34 [-1, 64, 274, 10] 128\n",
" MaxPool2d-35 [-1, 64, 91, 10] 0\n",
" Conv2d-36 [-1, 64, 91, 10] 4,160\n",
" ReLU-37 [-1, 64, 91, 10] 0\n",
" BatchNorm2d-38 [-1, 64, 91, 10] 128\n",
" Conv2d-39 [-1, 64, 91, 10] 4,160\n",
" ReLU-40 [-1, 64, 91, 10] 0\n",
" BatchNorm2d-41 [-1, 64, 91, 10] 128\n",
" MaxPool2d-42 [-1, 64, 89, 10] 0\n",
" Conv2d-43 [-1, 64, 89, 10] 4,160\n",
" ReLU-44 [-1, 64, 89, 10] 0\n",
" BatchNorm2d-45 [-1, 64, 89, 10] 128\n",
" Conv2d-46 [-1, 64, 89, 10] 4,160\n",
" ReLU-47 [-1, 64, 89, 10] 0\n",
" BatchNorm2d-48 [-1, 64, 89, 10] 128\n",
" MaxPool2d-49 [-1, 64, 87, 10] 0\n",
" Flatten-50 [-1, 55680] 0\n",
" Linear-51 [-1, 64] 3,563,584\n",
" ReLU-52 [-1, 64] 0\n",
" Linear-53 [-1, 32] 2,080\n",
" ReLU-54 [-1, 32] 0\n",
" Linear-55 [-1, 1] 33\n",
"================================================================\n",
"Total params: 3,622,977\n",
"Trainable params: 3,622,977\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.11\n",
"Forward/backward pass size (MB): 155.61\n",
"Params size (MB): 13.82\n",
"Estimated Total Size (MB): 169.54\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"summary(model, input_size=(1, 2500, 12))"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Net(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(1, 64, kernel_size=(7, 3), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv3): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv4): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv5): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv6): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv7): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (fc1): Sequential(\n",
" (0): Flatten()\n",
" (1): Linear(in_features=55680, out_features=64, bias=True)\n",
" (2): ReLU()\n",
" (3): Linear(in_features=64, out_features=32, bias=True)\n",
" (4): ReLU()\n",
" (5): Linear(in_features=32, out_features=1, bias=True)\n",
" )\n",
")\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import TensorDataset\n",
"import torch.optim as optim\n",
"from torch.optim import lr_scheduler\n",
"import numpy as np\n",
"import torchvision\n",
"import torch.nn.functional as F\n",
"from torch.utils.data.sampler import SubsetRandomSampler\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets, models, transforms\n",
"from torchvision.transforms import Resize, ToTensor, Normalize\n",
"import matplotlib.pyplot as plt\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"\n",
"from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_auc_score, \\\n",
" average_precision_score\n",
"from sklearn.model_selection import train_test_split\n",
"import time\n",
"import os\n",
"from pathlib import Path\n",
"from skimage import io\n",
"import copy\n",
"from torch import optim, cuda\n",
"import pandas as pd\n",
"import glob\n",
"from collections import Counter\n",
"# Useful for examining network\n",
"from functools import reduce\n",
"from operator import __add__\n",
"# from torchsummary import summary\n",
"import seaborn as sns\n",
"import warnings\n",
"# warnings.filterwarnings('ignore', category=FutureWarning)\n",
"from PIL import Image\n",
"from timeit import default_timer as timer\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Useful for examining network\n",
"from functools import reduce\n",
"from operator import __add__\n",
"from torchsummary import summary\n",
"\n",
"# from IPython.core.interactiveshell import InteractiveShell\n",
"import seaborn as sns\n",
"\n",
"import warnings\n",
"# warnings.filterwarnings('ignore', category=FutureWarning)\n",
"\n",
"# Image manipulations\n",
"from PIL import Image\n",
"\n",
"# Timing utility\n",
"from timeit import default_timer as timer\n",
"\n",
"# Visualizations\n",
"import matplotlib.pyplot as plt\n",
"class Net(nn.Module):\n",
"\n",
" # Convolution as a whole should span at least 1 beat, preferably more\n",
" # Input shape = (1,2500,12)\n",
"# summary(model, input_size=(1, 2500, 12))\n",
" # CLEAR EXPERIMENTS TO TRY\n",
"# Alter kernel size, right now drops to 10 channels at beginning then stays there\n",
"# Try increasing output channel size as you go deeper \n",
"# Alter stride to have larger image at FC layer\n",
"\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
"\n",
" base_conv = 64\n",
" self.conv1 = nn.Sequential( \n",
" nn.Conv2d(in_channels=1, out_channels=base_conv, kernel_size=(7,3), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(3,1))\n",
" )\n",
"\n",
" self.conv2 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n",
" )\n",
"\n",
" self.conv3 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(3,1))\n",
" )\n",
"\n",
" self.conv4 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n",
" )\n",
"\n",
" self.conv5 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(3,1))\n",
" )\n",
" self.conv6 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n",
" )\n",
"\n",
" self.conv7 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n",
" )\n",
"\n",
" self.fc1 = nn.Sequential(\n",
"# nn.AdaptiveAvgPool2d((1, 1)),\n",
" nn.Flatten(),\n",
" nn.Linear(64*87*10, 64), #64 kernel size, 2500 pooled to 29, \n",
" nn.ReLU(),\n",
" nn.Linear(64, 32),\n",
" nn.ReLU(),\n",
" nn.Linear(32, 1))\n",
"\n",
" def forward(self, x):\n",
" out = self.conv1(x)\n",
"# print(out.shape)\n",
" out = self.conv2(out)\n",
"# print(out.shape) \n",
" out = self.conv3(out)\n",
"# print(out.shape)\n",
" out = self.conv4(out)\n",
"# print(out.shape)\n",
" out = self.conv5(out)\n",
"# print(out.shape)\n",
" out = self.conv6(out)\n",
"# print(out.shape)\n",
" out = self.conv7(out)\n",
"# print(out.shape)\n",
" out = self.fc1(out)\n",
" return out\n",
"\n",
"model = Net()\n",
"print(model)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Net(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(1, 16, kernel_size=(5, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(16, 16, kernel_size=(5, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv3): Sequential(\n",
" (0): Conv2d(16, 32, kernel_size=(5, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(4, 1), stride=(4, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv4): Sequential(\n",
" (0): Conv2d(32, 32, kernel_size=(3, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv5): Sequential(\n",
" (0): Conv2d(32, 64, kernel_size=(3, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv6): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(3, 1), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (4): ReLU()\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): MaxPool2d(kernel_size=(4, 1), stride=(4, 1), padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (conv7): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(1, 12), stride=(1, 1))\n",
" (1): ReLU()\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (fc1): Sequential(\n",
" (0): Flatten()\n",
" (1): Linear(in_features=512, out_features=64, bias=True)\n",
" (2): ReLU()\n",
" (3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (4): Linear(in_features=64, out_features=32, bias=True)\n",
" (5): ReLU()\n",
" (6): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (7): Linear(in_features=32, out_features=1, bias=True)\n",
" )\n",
")\n"
]
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
"\n",
" base_conv = 16\n",
" self.conv1 = nn.Sequential( \n",
" nn.Conv2d(in_channels=1, out_channels=base_conv, kernel_size=(5,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(2,1), stride=(2,1)) \n",
" )\n",
"\n",
" self.conv2 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(5,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv),\n",
" nn.MaxPool2d(kernel_size=(2,1), stride=(2,1)) \n",
" )\n",
"\n",
" self.conv3 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv, out_channels=base_conv*2, kernel_size=(5,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv*2),\n",
" nn.Conv2d(in_channels=base_conv*2, out_channels=base_conv*2, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv*2),\n",
" nn.MaxPool2d(kernel_size=(4,1), stride=(4,1)) \n",
" )\n",
"\n",
" self.conv4 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv*2, out_channels=base_conv*2, kernel_size=(3,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv*2),\n",
" nn.Conv2d(in_channels=base_conv*2, out_channels=base_conv*2, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv*2),\n",
" nn.MaxPool2d(kernel_size=(2,1), stride=(2,1)) \n",
" )\n",
"\n",
" self.conv5 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv*2, out_channels=base_conv*4, kernel_size=(3,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv*4),\n",
" nn.Conv2d(in_channels=base_conv*4, out_channels=base_conv*4, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv*4),\n",
" nn.MaxPool2d(kernel_size=(2,1), stride=(2,1)) \n",
" )\n",
" self.conv6 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv*4, out_channels=base_conv*4, kernel_size=(3,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv*4),\n",
" nn.Conv2d(in_channels=base_conv*4, out_channels=base_conv*4, kernel_size=(1,1), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv*4),\n",
" nn.MaxPool2d(kernel_size=(4,1), stride=(4,1)) \n",
" )\n",
"\n",
" self.conv7 = nn.Sequential(\n",
" nn.Conv2d(in_channels=base_conv*4, out_channels=base_conv*4, kernel_size=(1,12), stride=1),\n",
" nn.ReLU(),\n",
" nn.BatchNorm2d(base_conv*4),\n",
" ) #Mayo describess using their 7th block as a spatial fusion block\n",
"\n",
" self.fc1 = nn.Sequential(\n",
"# nn.AdaptiveAvgPool2d((1, 1)),\n",
" nn.Flatten(),\n",
" nn.Linear(64*8*1, 64), #64 kernel size, #2500 decreased to 8, x1 lead\n",
" nn.ReLU(),\n",
" nn.BatchNorm1d(64),\n",
" nn.Linear(64, 32),\n",
" nn.ReLU(),\n",
" nn.BatchNorm1d(32),\n",
"# nn.Dropout(0.5),\n",
" nn.Linear(32, 1))\n",
"\n",
" def forward(self, x):\n",
" out = self.conv1(x)\n",
"# print(out.shape)\n",
" out = self.conv2(out)\n",
"# print(out.shape) \n",
" out = self.conv3(out)\n",
"# print(out.shape)\n",
" out = self.conv4(out)\n",
"# print(out.shape)\n",
" out = self.conv5(out)\n",
"# print(out.shape)\n",
" out = self.conv6(out)\n",
"# print(out.shape)\n",
" out = self.conv7(out)\n",
"# print(out.shape)\n",
" out = self.fc1(out)\n",
" return out\n",
"\n",
"model = Net()\n",
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 16, 2496, 12] 96\n",
" ReLU-2 [-1, 16, 2496, 12] 0\n",
" BatchNorm2d-3 [-1, 16, 2496, 12] 32\n",
" Conv2d-4 [-1, 16, 2496, 12] 272\n",
" ReLU-5 [-1, 16, 2496, 12] 0\n",
" BatchNorm2d-6 [-1, 16, 2496, 12] 32\n",
" MaxPool2d-7 [-1, 16, 1248, 12] 0\n",
" Conv2d-8 [-1, 16, 1244, 12] 1,296\n",
" ReLU-9 [-1, 16, 1244, 12] 0\n",
" BatchNorm2d-10 [-1, 16, 1244, 12] 32\n",
" Conv2d-11 [-1, 16, 1244, 12] 272\n",
" ReLU-12 [-1, 16, 1244, 12] 0\n",
" BatchNorm2d-13 [-1, 16, 1244, 12] 32\n",
" MaxPool2d-14 [-1, 16, 622, 12] 0\n",
" Conv2d-15 [-1, 32, 618, 12] 2,592\n",
" ReLU-16 [-1, 32, 618, 12] 0\n",
" BatchNorm2d-17 [-1, 32, 618, 12] 64\n",
" Conv2d-18 [-1, 32, 618, 12] 1,056\n",
" ReLU-19 [-1, 32, 618, 12] 0\n",
" BatchNorm2d-20 [-1, 32, 618, 12] 64\n",
" MaxPool2d-21 [-1, 32, 154, 12] 0\n",
" Conv2d-22 [-1, 32, 152, 12] 3,104\n",
" ReLU-23 [-1, 32, 152, 12] 0\n",
" BatchNorm2d-24 [-1, 32, 152, 12] 64\n",
" Conv2d-25 [-1, 32, 152, 12] 1,056\n",
" ReLU-26 [-1, 32, 152, 12] 0\n",
" BatchNorm2d-27 [-1, 32, 152, 12] 64\n",
" MaxPool2d-28 [-1, 32, 76, 12] 0\n",
" Conv2d-29 [-1, 64, 74, 12] 6,208\n",
" ReLU-30 [-1, 64, 74, 12] 0\n",
" BatchNorm2d-31 [-1, 64, 74, 12] 128\n",
" Conv2d-32 [-1, 64, 74, 12] 4,160\n",
" ReLU-33 [-1, 64, 74, 12] 0\n",
" BatchNorm2d-34 [-1, 64, 74, 12] 128\n",
" MaxPool2d-35 [-1, 64, 37, 12] 0\n",
" Conv2d-36 [-1, 64, 35, 12] 12,352\n",
" ReLU-37 [-1, 64, 35, 12] 0\n",
" BatchNorm2d-38 [-1, 64, 35, 12] 128\n",
" Conv2d-39 [-1, 64, 35, 12] 4,160\n",
" ReLU-40 [-1, 64, 35, 12] 0\n",
" BatchNorm2d-41 [-1, 64, 35, 12] 128\n",
" MaxPool2d-42 [-1, 64, 8, 12] 0\n",
" Conv2d-43 [-1, 64, 8, 1] 49,216\n",
" ReLU-44 [-1, 64, 8, 1] 0\n",
" BatchNorm2d-45 [-1, 64, 8, 1] 128\n",
" Flatten-46 [-1, 512] 0\n",
" Linear-47 [-1, 64] 32,832\n",
" ReLU-48 [-1, 64] 0\n",
" BatchNorm1d-49 [-1, 64] 128\n",
" Linear-50 [-1, 32] 2,080\n",
" ReLU-51 [-1, 32] 0\n",
" BatchNorm1d-52 [-1, 32] 64\n",
" Linear-53 [-1, 1] 33\n",
"================================================================\n",
"Total params: 122,001\n",
"Trainable params: 122,001\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.11\n",
"Forward/backward pass size (MB): 53.93\n",
"Params size (MB): 0.47\n",
"Estimated Total Size (MB): 54.51\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"summary(model, input_size=(1, 2500, 12))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}