--- a
+++ b/4-Models/Model_Parameter_Evaluator.ipynb
@@ -0,0 +1,4166 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.data import TensorDataset\n",
+ "import torch.optim as optim\n",
+ "from torch.optim import lr_scheduler\n",
+ "import numpy as np\n",
+ "import torchvision\n",
+ "import torch.nn.functional as F\n",
+ "from torch.utils.data.sampler import SubsetRandomSampler\n",
+ "from torch.utils.data import DataLoader\n",
+ "from torchvision import datasets, models, transforms\n",
+ "from torchvision.transforms import Resize, ToTensor, Normalize\n",
+ "from imblearn.under_sampling import RandomUnderSampler\n",
+ "\n",
+ "from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_auc_score, \\\n",
+ " average_precision_score\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "import time\n",
+ "import os\n",
+ "from pathlib import Path\n",
+ "from skimage import io\n",
+ "import copy\n",
+ "from torch import optim, cuda\n",
+ "import pandas as pd\n",
+ "import glob\n",
+ "from collections import Counter\n",
+ "\n",
+ "# Useful for examining the network\n",
+ "from functools import reduce\n",
+ "from operator import __add__\n",
+ "from torchsummary import summary\n",
+ "\n",
+ "# from IPython.core.interactiveshell import InteractiveShell\n",
+ "import seaborn as sns\n",
+ "\n",
+ "import warnings\n",
+ "# warnings.filterwarnings('ignore', category=FutureWarning)\n",
+ "\n",
+ "# Image manipulations\n",
+ "from PIL import Image\n",
+ "\n",
+ "# Timing utility\n",
+ "from timeit import default_timer as timer\n",
+ "\n",
+ "# Visualizations\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "\n",
+ "class Multi_2D_CNN_block(nn.Module):\n",
+ " def __init__(self, in_channels, num_kernel):\n",
+ " super(Multi_2D_CNN_block, self).__init__()\n",
+ " conv_block = BasicConv2d\n",
+ " self.a = conv_block(in_channels, int(num_kernel / 3), kernel_size=(1, 1))\n",
+ "\n",
+ " self.b = nn.Sequential(\n",
+ " conv_block(in_channels, int(num_kernel / 2), kernel_size=(1, 1)),\n",
+ " conv_block(int(num_kernel / 2), int(num_kernel), kernel_size=(3, 3))\n",
+ " )\n",
+ "\n",
+ " self.c = nn.Sequential(\n",
+ " conv_block(in_channels, int(num_kernel / 3), kernel_size=(1, 1)),\n",
+ " conv_block(int(num_kernel / 3), int(num_kernel / 2), kernel_size=(3, 3)),\n",
+ " conv_block(int(num_kernel / 2), int(num_kernel), kernel_size=(3, 3))\n",
+ " )\n",
+ " self.out_channels = int(num_kernel / 3) + int(num_kernel) + int(num_kernel)\n",
+ " # out_channels is the total number of output channels across branches a/b/c\n",
+ " self.bn = nn.BatchNorm2d(self.out_channels)\n",
+ "\n",
+ " def get_out_channels(self):\n",
+ " return self.out_channels\n",
+ "\n",
+ " def forward(self, x):\n",
+ " branch1 = self.a(x)\n",
+ " branch2 = self.b(x)\n",
+ " branch3 = self.c(x)\n",
+ " output = [branch1, branch2, branch3]\n",
+ " return self.bn(torch.cat(output,\n",
+ " 1)) # BatchNorm across the concatenation of output channels from final 
layer of Branch 1/2/3\n", + " # ,1 refers to the channel dimension\n", + "\n", + "\n", + "class BasicConv2d(nn.Module):\n", + " def __init__(self, in_channels, out_channels, kernel_size, **kwargs):\n", + " super(BasicConv2d, self).__init__()\n", + " conv_padding = reduce(__add__, [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in kernel_size[::-1]])\n", + " self.pad = nn.ZeroPad2d(conv_padding)\n", + " # ZeroPad2d Output: :math:`(N, C, H_{out}, W_{out})` H_{out} is H_{in} with the padding to be added to either side of height\n", + " # ZeroPad2d(2) would add 2 to all 4 sides, ZeroPad2d((1,1,2,0)) would add 1 left, 1 right, 2 above, 0 below\n", + " # n_output_features = floor((n_input_features + 2(paddingsize) - convkernel_size) / stride_size) + 1\n", + " # above creates same padding\n", + " self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=False, **kwargs)\n", + " self.bn = nn.BatchNorm2d(out_channels)\n", + "\n", + " def forward(self, x):\n", + " x = self.pad(x)\n", + " x = self.conv(x)\n", + " x = F.relu(x, inplace=True)\n", + " x = self.bn(x)\n", + " return x\n", + "\n", + "\n", + "class MyModel(nn.Module):\n", + "\n", + " def __init__(self, initial_kernel_num=64):\n", + " super(MyModel, self).__init__()\n", + "\n", + " multi_2d_cnn = Multi_2D_CNN_block\n", + " conv_block = BasicConv2d\n", + "\n", + " self.conv_1 = conv_block(1, 64, kernel_size=(7, 3), stride=(2, 1))\n", + "\n", + " self.multi_2d_cnn_1a = nn.Sequential(\n", + " multi_2d_cnn(in_channels=64, num_kernel=initial_kernel_num),\n", + " multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num),\n", + " nn.MaxPool2d(kernel_size=(3, 1))\n", + " )\n", + "\n", + " self.multi_2d_cnn_1b = nn.Sequential(\n", + " multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num * 1.5),\n", + " multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 1.5),\n", + " nn.MaxPool2d(kernel_size=(3, 1))\n", + " )\n", + " self.multi_2d_cnn_1c = nn.Sequential(\n", + " multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 2),\n", + " multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 2),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + "\n", + " self.multi_2d_cnn_2a = nn.Sequential(\n", + " multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 3),\n", + " multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 3),\n", + " multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 4),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2b = nn.Sequential(\n", + " multi_2d_cnn(in_channels=597, num_kernel=initial_kernel_num * 5),\n", + " multi_2d_cnn(in_channels=746, num_kernel=initial_kernel_num * 6),\n", + " multi_2d_cnn(in_channels=896, num_kernel=initial_kernel_num * 7),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2c = nn.Sequential(\n", + " multi_2d_cnn(in_channels=1045, num_kernel=initial_kernel_num * 8),\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2d = nn.Sequential(\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 12),\n", + " multi_2d_cnn(in_channels=1792, num_kernel=initial_kernel_num * 14),\n", + " multi_2d_cnn(in_channels=2090, num_kernel=initial_kernel_num * 16),\n", + " )\n", + " self.output = nn.Sequential(\n", + " nn.AdaptiveAvgPool2d((1, 1)),\n", + " nn.Flatten(),\n", + " nn.Dropout(0.5),\n", + " 
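" # 2389 = out_channels of the final Multi_2D_CNN_block: int(1024 / 3) + 1024 + 1024 (see BatchNorm2d-477 in the summary below)\n",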
nn.Linear(2389, 1),\n", + " # nn.Sigmoid()\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.conv_1(x)\n", + " # N x 1250 x 12 x 64 tensor\n", + " x = self.multi_2d_cnn_1a(x)\n", + " # N x 416 x 12 x 149 tensor\n", + " x = self.multi_2d_cnn_1b(x)\n", + " # N x 138 x 12 x 224 tensor\n", + " x = self.multi_2d_cnn_1c(x)\n", + " # N x 69 x 12 x 298\n", + " x = self.multi_2d_cnn_2a(x)\n", + "\n", + " x = self.multi_2d_cnn_2b(x)\n", + "\n", + " x = self.multi_2d_cnn_2c(x)\n", + "\n", + " x = self.multi_2d_cnn_2d(x)\n", + "\n", + " x = self.output(x)\n", + "\n", + " return x\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on: cpu\n" + ] + } + ], + "source": [ + "\n", + "model = MyModel()\n", + "\n", + "# Device check\n", + "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f'Train on: {device}')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " ZeroPad2d-1 [-1, 1, 2506, 14] 0\n", + " Conv2d-2 [-1, 64, 1250, 12] 1,344\n", + " BatchNorm2d-3 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-4 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-5 [-1, 64, 1250, 12] 0\n", + " Conv2d-6 [-1, 21, 1250, 12] 1,344\n", + " BatchNorm2d-7 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-8 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-9 [-1, 64, 1250, 12] 0\n", + " Conv2d-10 [-1, 32, 1250, 12] 2,048\n", + " BatchNorm2d-11 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-12 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-13 [-1, 32, 1252, 14] 0\n", + " Conv2d-14 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-15 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-16 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-17 [-1, 64, 1250, 12] 0\n", + " Conv2d-18 [-1, 21, 1250, 12] 1,344\n", + " BatchNorm2d-19 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-20 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-21 [-1, 21, 1252, 14] 0\n", + " Conv2d-22 [-1, 32, 1250, 12] 6,048\n", + " BatchNorm2d-23 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-24 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-25 [-1, 32, 1252, 14] 0\n", + " Conv2d-26 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-27 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-28 [-1, 64, 1250, 12] 0\n", + " BatchNorm2d-29 [-1, 149, 1250, 12] 298\n", + "Multi_2D_CNN_block-30 [-1, 149, 1250, 12] 0\n", + " ZeroPad2d-31 [-1, 149, 1250, 12] 0\n", + " Conv2d-32 [-1, 21, 1250, 12] 3,129\n", + " BatchNorm2d-33 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-34 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-35 [-1, 149, 1250, 12] 0\n", + " Conv2d-36 [-1, 32, 1250, 12] 4,768\n", + " BatchNorm2d-37 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-38 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-39 [-1, 32, 1252, 14] 0\n", + " Conv2d-40 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-41 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-42 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-43 [-1, 149, 1250, 12] 0\n", + " Conv2d-44 [-1, 21, 1250, 12] 3,129\n", + " BatchNorm2d-45 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-46 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-47 [-1, 21, 1252, 14] 0\n", + " Conv2d-48 [-1, 32, 1250, 12] 6,048\n", + " BatchNorm2d-49 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-50 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-51 
[-1, 32, 1252, 14] 0\n", + " Conv2d-52 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-53 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-54 [-1, 64, 1250, 12] 0\n", + " BatchNorm2d-55 [-1, 149, 1250, 12] 298\n", + "Multi_2D_CNN_block-56 [-1, 149, 1250, 12] 0\n", + " MaxPool2d-57 [-1, 149, 416, 12] 0\n", + " ZeroPad2d-58 [-1, 149, 416, 12] 0\n", + " Conv2d-59 [-1, 32, 416, 12] 4,768\n", + " BatchNorm2d-60 [-1, 32, 416, 12] 64\n", + " BasicConv2d-61 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-62 [-1, 149, 416, 12] 0\n", + " Conv2d-63 [-1, 48, 416, 12] 7,152\n", + " BatchNorm2d-64 [-1, 48, 416, 12] 96\n", + " BasicConv2d-65 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-66 [-1, 48, 418, 14] 0\n", + " Conv2d-67 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-68 [-1, 96, 416, 12] 192\n", + " BasicConv2d-69 [-1, 96, 416, 12] 0\n", + " ZeroPad2d-70 [-1, 149, 416, 12] 0\n", + " Conv2d-71 [-1, 32, 416, 12] 4,768\n", + " BatchNorm2d-72 [-1, 32, 416, 12] 64\n", + " BasicConv2d-73 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-74 [-1, 32, 418, 14] 0\n", + " Conv2d-75 [-1, 48, 416, 12] 13,824\n", + " BatchNorm2d-76 [-1, 48, 416, 12] 96\n", + " BasicConv2d-77 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-78 [-1, 48, 418, 14] 0\n", + " Conv2d-79 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-80 [-1, 96, 416, 12] 192\n", + " BasicConv2d-81 [-1, 96, 416, 12] 0\n", + " BatchNorm2d-82 [-1, 224, 416, 12] 448\n", + "Multi_2D_CNN_block-83 [-1, 224, 416, 12] 0\n", + " ZeroPad2d-84 [-1, 224, 416, 12] 0\n", + " Conv2d-85 [-1, 32, 416, 12] 7,168\n", + " BatchNorm2d-86 [-1, 32, 416, 12] 64\n", + " BasicConv2d-87 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-88 [-1, 224, 416, 12] 0\n", + " Conv2d-89 [-1, 48, 416, 12] 10,752\n", + " BatchNorm2d-90 [-1, 48, 416, 12] 96\n", + " BasicConv2d-91 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-92 [-1, 48, 418, 14] 0\n", + " Conv2d-93 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-94 [-1, 96, 416, 12] 192\n", + " BasicConv2d-95 [-1, 96, 416, 12] 0\n", + " ZeroPad2d-96 [-1, 224, 416, 12] 0\n", + " Conv2d-97 [-1, 32, 416, 12] 7,168\n", + " BatchNorm2d-98 [-1, 32, 416, 12] 64\n", + " BasicConv2d-99 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-100 [-1, 32, 418, 14] 0\n", + " Conv2d-101 [-1, 48, 416, 12] 13,824\n", + " BatchNorm2d-102 [-1, 48, 416, 12] 96\n", + " BasicConv2d-103 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-104 [-1, 48, 418, 14] 0\n", + " Conv2d-105 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-106 [-1, 96, 416, 12] 192\n", + " BasicConv2d-107 [-1, 96, 416, 12] 0\n", + " BatchNorm2d-108 [-1, 224, 416, 12] 448\n", + "Multi_2D_CNN_block-109 [-1, 224, 416, 12] 0\n", + " MaxPool2d-110 [-1, 224, 138, 12] 0\n", + " ZeroPad2d-111 [-1, 224, 138, 12] 0\n", + " Conv2d-112 [-1, 42, 138, 12] 9,408\n", + " BatchNorm2d-113 [-1, 42, 138, 12] 84\n", + " BasicConv2d-114 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-115 [-1, 224, 138, 12] 0\n", + " Conv2d-116 [-1, 64, 138, 12] 14,336\n", + " BatchNorm2d-117 [-1, 64, 138, 12] 128\n", + " BasicConv2d-118 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-119 [-1, 64, 140, 14] 0\n", + " Conv2d-120 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-121 [-1, 128, 138, 12] 256\n", + " BasicConv2d-122 [-1, 128, 138, 12] 0\n", + " ZeroPad2d-123 [-1, 224, 138, 12] 0\n", + " Conv2d-124 [-1, 42, 138, 12] 9,408\n", + " BatchNorm2d-125 [-1, 42, 138, 12] 84\n", + " BasicConv2d-126 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-127 [-1, 42, 140, 14] 0\n", + " Conv2d-128 [-1, 64, 138, 12] 24,192\n", + " BatchNorm2d-129 [-1, 64, 138, 12] 128\n", + " BasicConv2d-130 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-131 [-1, 64, 140, 14] 0\n", + " Conv2d-132 [-1, 128, 138, 
12] 73,728\n", + " BatchNorm2d-133 [-1, 128, 138, 12] 256\n", + " BasicConv2d-134 [-1, 128, 138, 12] 0\n", + " BatchNorm2d-135 [-1, 298, 138, 12] 596\n", + "Multi_2D_CNN_block-136 [-1, 298, 138, 12] 0\n", + " ZeroPad2d-137 [-1, 298, 138, 12] 0\n", + " Conv2d-138 [-1, 42, 138, 12] 12,516\n", + " BatchNorm2d-139 [-1, 42, 138, 12] 84\n", + " BasicConv2d-140 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-141 [-1, 298, 138, 12] 0\n", + " Conv2d-142 [-1, 64, 138, 12] 19,072\n", + " BatchNorm2d-143 [-1, 64, 138, 12] 128\n", + " BasicConv2d-144 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-145 [-1, 64, 140, 14] 0\n", + " Conv2d-146 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-147 [-1, 128, 138, 12] 256\n", + " BasicConv2d-148 [-1, 128, 138, 12] 0\n", + " ZeroPad2d-149 [-1, 298, 138, 12] 0\n", + " Conv2d-150 [-1, 42, 138, 12] 12,516\n", + " BatchNorm2d-151 [-1, 42, 138, 12] 84\n", + " BasicConv2d-152 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-153 [-1, 42, 140, 14] 0\n", + " Conv2d-154 [-1, 64, 138, 12] 24,192\n", + " BatchNorm2d-155 [-1, 64, 138, 12] 128\n", + " BasicConv2d-156 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-157 [-1, 64, 140, 14] 0\n", + " Conv2d-158 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-159 [-1, 128, 138, 12] 256\n", + " BasicConv2d-160 [-1, 128, 138, 12] 0\n", + " BatchNorm2d-161 [-1, 298, 138, 12] 596\n", + "Multi_2D_CNN_block-162 [-1, 298, 138, 12] 0\n", + " MaxPool2d-163 [-1, 298, 69, 12] 0\n", + " ZeroPad2d-164 [-1, 298, 69, 12] 0\n", + " Conv2d-165 [-1, 64, 69, 12] 19,072\n", + " BatchNorm2d-166 [-1, 64, 69, 12] 128\n", + " BasicConv2d-167 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-168 [-1, 298, 69, 12] 0\n", + " Conv2d-169 [-1, 96, 69, 12] 28,608\n", + " BatchNorm2d-170 [-1, 96, 69, 12] 192\n", + " BasicConv2d-171 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-172 [-1, 96, 71, 14] 0\n", + " Conv2d-173 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-174 [-1, 192, 69, 12] 384\n", + " BasicConv2d-175 [-1, 192, 69, 12] 0\n", + " ZeroPad2d-176 [-1, 298, 69, 12] 0\n", + " Conv2d-177 [-1, 64, 69, 12] 19,072\n", + " BatchNorm2d-178 [-1, 64, 69, 12] 128\n", + " BasicConv2d-179 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-180 [-1, 64, 71, 14] 0\n", + " Conv2d-181 [-1, 96, 69, 12] 55,296\n", + " BatchNorm2d-182 [-1, 96, 69, 12] 192\n", + " BasicConv2d-183 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-184 [-1, 96, 71, 14] 0\n", + " Conv2d-185 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-186 [-1, 192, 69, 12] 384\n", + " BasicConv2d-187 [-1, 192, 69, 12] 0\n", + " BatchNorm2d-188 [-1, 448, 69, 12] 896\n", + "Multi_2D_CNN_block-189 [-1, 448, 69, 12] 0\n", + " ZeroPad2d-190 [-1, 448, 69, 12] 0\n", + " Conv2d-191 [-1, 64, 69, 12] 28,672\n", + " BatchNorm2d-192 [-1, 64, 69, 12] 128\n", + " BasicConv2d-193 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-194 [-1, 448, 69, 12] 0\n", + " Conv2d-195 [-1, 96, 69, 12] 43,008\n", + " BatchNorm2d-196 [-1, 96, 69, 12] 192\n", + " BasicConv2d-197 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-198 [-1, 96, 71, 14] 0\n", + " Conv2d-199 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-200 [-1, 192, 69, 12] 384\n", + " BasicConv2d-201 [-1, 192, 69, 12] 0\n", + " ZeroPad2d-202 [-1, 448, 69, 12] 0\n", + " Conv2d-203 [-1, 64, 69, 12] 28,672\n", + " BatchNorm2d-204 [-1, 64, 69, 12] 128\n", + " BasicConv2d-205 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-206 [-1, 64, 71, 14] 0\n", + " Conv2d-207 [-1, 96, 69, 12] 55,296\n", + " BatchNorm2d-208 [-1, 96, 69, 12] 192\n", + " BasicConv2d-209 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-210 [-1, 96, 71, 14] 0\n", + " Conv2d-211 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-212 [-1, 192, 69, 12] 384\n", + " 
BasicConv2d-213 [-1, 192, 69, 12] 0\n", + " BatchNorm2d-214 [-1, 448, 69, 12] 896\n", + "Multi_2D_CNN_block-215 [-1, 448, 69, 12] 0\n", + " ZeroPad2d-216 [-1, 448, 69, 12] 0\n", + " Conv2d-217 [-1, 85, 69, 12] 38,080\n", + " BatchNorm2d-218 [-1, 85, 69, 12] 170\n", + " BasicConv2d-219 [-1, 85, 69, 12] 0\n", + " ZeroPad2d-220 [-1, 448, 69, 12] 0\n", + " Conv2d-221 [-1, 128, 69, 12] 57,344\n", + " BatchNorm2d-222 [-1, 128, 69, 12] 256\n", + " BasicConv2d-223 [-1, 128, 69, 12] 0\n", + " ZeroPad2d-224 [-1, 128, 71, 14] 0\n", + " Conv2d-225 [-1, 256, 69, 12] 294,912\n", + " BatchNorm2d-226 [-1, 256, 69, 12] 512\n", + " BasicConv2d-227 [-1, 256, 69, 12] 0\n", + " ZeroPad2d-228 [-1, 448, 69, 12] 0\n", + " Conv2d-229 [-1, 85, 69, 12] 38,080\n", + " BatchNorm2d-230 [-1, 85, 69, 12] 170\n", + " BasicConv2d-231 [-1, 85, 69, 12] 0\n", + " ZeroPad2d-232 [-1, 85, 71, 14] 0\n", + " Conv2d-233 [-1, 128, 69, 12] 97,920\n", + " BatchNorm2d-234 [-1, 128, 69, 12] 256\n", + " BasicConv2d-235 [-1, 128, 69, 12] 0\n", + " ZeroPad2d-236 [-1, 128, 71, 14] 0\n", + " Conv2d-237 [-1, 256, 69, 12] 294,912\n", + " BatchNorm2d-238 [-1, 256, 69, 12] 512\n", + " BasicConv2d-239 [-1, 256, 69, 12] 0\n", + " BatchNorm2d-240 [-1, 597, 69, 12] 1,194\n", + "Multi_2D_CNN_block-241 [-1, 597, 69, 12] 0\n", + " MaxPool2d-242 [-1, 597, 34, 12] 0\n", + " ZeroPad2d-243 [-1, 597, 34, 12] 0\n", + " Conv2d-244 [-1, 106, 34, 12] 63,282\n", + " BatchNorm2d-245 [-1, 106, 34, 12] 212\n", + " BasicConv2d-246 [-1, 106, 34, 12] 0\n", + " ZeroPad2d-247 [-1, 597, 34, 12] 0\n", + " Conv2d-248 [-1, 160, 34, 12] 95,520\n", + " BatchNorm2d-249 [-1, 160, 34, 12] 320\n", + " BasicConv2d-250 [-1, 160, 34, 12] 0\n", + " ZeroPad2d-251 [-1, 160, 36, 14] 0\n", + " Conv2d-252 [-1, 320, 34, 12] 460,800\n", + " BatchNorm2d-253 [-1, 320, 34, 12] 640\n", + " BasicConv2d-254 [-1, 320, 34, 12] 0\n", + " ZeroPad2d-255 [-1, 597, 34, 12] 0\n", + " Conv2d-256 [-1, 106, 34, 12] 63,282\n", + " BatchNorm2d-257 [-1, 106, 34, 12] 212\n", + " BasicConv2d-258 [-1, 106, 34, 12] 0\n", + " ZeroPad2d-259 [-1, 106, 36, 14] 0\n", + " Conv2d-260 [-1, 160, 34, 12] 152,640\n", + " BatchNorm2d-261 [-1, 160, 34, 12] 320\n", + " BasicConv2d-262 [-1, 160, 34, 12] 0\n", + " ZeroPad2d-263 [-1, 160, 36, 14] 0\n", + " Conv2d-264 [-1, 320, 34, 12] 460,800\n", + " BatchNorm2d-265 [-1, 320, 34, 12] 640\n", + " BasicConv2d-266 [-1, 320, 34, 12] 0\n", + " BatchNorm2d-267 [-1, 746, 34, 12] 1,492\n", + "Multi_2D_CNN_block-268 [-1, 746, 34, 12] 0\n", + " ZeroPad2d-269 [-1, 746, 34, 12] 0\n", + " Conv2d-270 [-1, 128, 34, 12] 95,488\n", + " BatchNorm2d-271 [-1, 128, 34, 12] 256\n", + " BasicConv2d-272 [-1, 128, 34, 12] 0\n", + " ZeroPad2d-273 [-1, 746, 34, 12] 0\n", + " Conv2d-274 [-1, 192, 34, 12] 143,232\n", + " BatchNorm2d-275 [-1, 192, 34, 12] 384\n", + " BasicConv2d-276 [-1, 192, 34, 12] 0\n", + " ZeroPad2d-277 [-1, 192, 36, 14] 0\n", + " Conv2d-278 [-1, 384, 34, 12] 663,552\n", + " BatchNorm2d-279 [-1, 384, 34, 12] 768\n", + " BasicConv2d-280 [-1, 384, 34, 12] 0\n", + " ZeroPad2d-281 [-1, 746, 34, 12] 0\n", + " Conv2d-282 [-1, 128, 34, 12] 95,488\n", + " BatchNorm2d-283 [-1, 128, 34, 12] 256\n", + " BasicConv2d-284 [-1, 128, 34, 12] 0\n", + " ZeroPad2d-285 [-1, 128, 36, 14] 0\n", + " Conv2d-286 [-1, 192, 34, 12] 221,184\n", + " BatchNorm2d-287 [-1, 192, 34, 12] 384\n", + " BasicConv2d-288 [-1, 192, 34, 12] 0\n", + " ZeroPad2d-289 [-1, 192, 36, 14] 0\n", + " Conv2d-290 [-1, 384, 34, 12] 663,552\n", + " BatchNorm2d-291 [-1, 384, 34, 12] 768\n", + " BasicConv2d-292 [-1, 384, 34, 12] 0\n", + " 
BatchNorm2d-293 [-1, 896, 34, 12] 1,792\n", + "Multi_2D_CNN_block-294 [-1, 896, 34, 12] 0\n", + " ZeroPad2d-295 [-1, 896, 34, 12] 0\n", + " Conv2d-296 [-1, 149, 34, 12] 133,504\n", + " BatchNorm2d-297 [-1, 149, 34, 12] 298\n", + " BasicConv2d-298 [-1, 149, 34, 12] 0\n", + " ZeroPad2d-299 [-1, 896, 34, 12] 0\n", + " Conv2d-300 [-1, 224, 34, 12] 200,704\n", + " BatchNorm2d-301 [-1, 224, 34, 12] 448\n", + " BasicConv2d-302 [-1, 224, 34, 12] 0\n", + " ZeroPad2d-303 [-1, 224, 36, 14] 0\n", + " Conv2d-304 [-1, 448, 34, 12] 903,168\n", + " BatchNorm2d-305 [-1, 448, 34, 12] 896\n", + " BasicConv2d-306 [-1, 448, 34, 12] 0\n", + " ZeroPad2d-307 [-1, 896, 34, 12] 0\n", + " Conv2d-308 [-1, 149, 34, 12] 133,504\n", + " BatchNorm2d-309 [-1, 149, 34, 12] 298\n", + " BasicConv2d-310 [-1, 149, 34, 12] 0\n", + " ZeroPad2d-311 [-1, 149, 36, 14] 0\n", + " Conv2d-312 [-1, 224, 34, 12] 300,384\n", + " BatchNorm2d-313 [-1, 224, 34, 12] 448\n", + " BasicConv2d-314 [-1, 224, 34, 12] 0\n", + " ZeroPad2d-315 [-1, 224, 36, 14] 0\n", + " Conv2d-316 [-1, 448, 34, 12] 903,168\n", + " BatchNorm2d-317 [-1, 448, 34, 12] 896\n", + " BasicConv2d-318 [-1, 448, 34, 12] 0\n", + " BatchNorm2d-319 [-1, 1045, 34, 12] 2,090\n", + "Multi_2D_CNN_block-320 [-1, 1045, 34, 12] 0\n", + " MaxPool2d-321 [-1, 1045, 17, 12] 0\n", + " ZeroPad2d-322 [-1, 1045, 17, 12] 0\n", + " Conv2d-323 [-1, 170, 17, 12] 177,650\n", + " BatchNorm2d-324 [-1, 170, 17, 12] 340\n", + " BasicConv2d-325 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-326 [-1, 1045, 17, 12] 0\n", + " Conv2d-327 [-1, 256, 17, 12] 267,520\n", + " BatchNorm2d-328 [-1, 256, 17, 12] 512\n", + " BasicConv2d-329 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-330 [-1, 256, 19, 14] 0\n", + " Conv2d-331 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-332 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-333 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-334 [-1, 1045, 17, 12] 0\n", + " Conv2d-335 [-1, 170, 17, 12] 177,650\n", + " BatchNorm2d-336 [-1, 170, 17, 12] 340\n", + " BasicConv2d-337 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-338 [-1, 170, 19, 14] 0\n", + " Conv2d-339 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-340 [-1, 256, 17, 12] 512\n", + " BasicConv2d-341 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-342 [-1, 256, 19, 14] 0\n", + " Conv2d-343 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-344 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-345 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-346 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-347 [-1, 1194, 17, 12] 0\n", + " ZeroPad2d-348 [-1, 1194, 17, 12] 0\n", + " Conv2d-349 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-350 [-1, 170, 17, 12] 340\n", + " BasicConv2d-351 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-352 [-1, 1194, 17, 12] 0\n", + " Conv2d-353 [-1, 256, 17, 12] 305,664\n", + " BatchNorm2d-354 [-1, 256, 17, 12] 512\n", + " BasicConv2d-355 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-356 [-1, 256, 19, 14] 0\n", + " Conv2d-357 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-358 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-359 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-360 [-1, 1194, 17, 12] 0\n", + " Conv2d-361 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-362 [-1, 170, 17, 12] 340\n", + " BasicConv2d-363 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-364 [-1, 170, 19, 14] 0\n", + " Conv2d-365 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-366 [-1, 256, 17, 12] 512\n", + " BasicConv2d-367 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-368 [-1, 256, 19, 14] 0\n", + " Conv2d-369 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-370 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-371 [-1, 512, 17, 12] 0\n", + " 
BatchNorm2d-372 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-373 [-1, 1194, 17, 12] 0\n", + " ZeroPad2d-374 [-1, 1194, 17, 12] 0\n", + " Conv2d-375 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-376 [-1, 170, 17, 12] 340\n", + " BasicConv2d-377 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-378 [-1, 1194, 17, 12] 0\n", + " Conv2d-379 [-1, 256, 17, 12] 305,664\n", + " BatchNorm2d-380 [-1, 256, 17, 12] 512\n", + " BasicConv2d-381 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-382 [-1, 256, 19, 14] 0\n", + " Conv2d-383 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-384 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-385 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-386 [-1, 1194, 17, 12] 0\n", + " Conv2d-387 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-388 [-1, 170, 17, 12] 340\n", + " BasicConv2d-389 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-390 [-1, 170, 19, 14] 0\n", + " Conv2d-391 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-392 [-1, 256, 17, 12] 512\n", + " BasicConv2d-393 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-394 [-1, 256, 19, 14] 0\n", + " Conv2d-395 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-396 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-397 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-398 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-399 [-1, 1194, 17, 12] 0\n", + " MaxPool2d-400 [-1, 1194, 8, 12] 0\n", + " ZeroPad2d-401 [-1, 1194, 8, 12] 0\n", + " Conv2d-402 [-1, 256, 8, 12] 305,664\n", + " BatchNorm2d-403 [-1, 256, 8, 12] 512\n", + " BasicConv2d-404 [-1, 256, 8, 12] 0\n", + " ZeroPad2d-405 [-1, 1194, 8, 12] 0\n", + " Conv2d-406 [-1, 384, 8, 12] 458,496\n", + " BatchNorm2d-407 [-1, 384, 8, 12] 768\n", + " BasicConv2d-408 [-1, 384, 8, 12] 0\n", + " ZeroPad2d-409 [-1, 384, 10, 14] 0\n", + " Conv2d-410 [-1, 768, 8, 12] 2,654,208\n", + " BatchNorm2d-411 [-1, 768, 8, 12] 1,536\n", + " BasicConv2d-412 [-1, 768, 8, 12] 0\n", + " ZeroPad2d-413 [-1, 1194, 8, 12] 0\n", + " Conv2d-414 [-1, 256, 8, 12] 305,664\n", + " BatchNorm2d-415 [-1, 256, 8, 12] 512\n", + " BasicConv2d-416 [-1, 256, 8, 12] 0\n", + " ZeroPad2d-417 [-1, 256, 10, 14] 0\n", + " Conv2d-418 [-1, 384, 8, 12] 884,736\n", + " BatchNorm2d-419 [-1, 384, 8, 12] 768\n", + " BasicConv2d-420 [-1, 384, 8, 12] 0\n", + " ZeroPad2d-421 [-1, 384, 10, 14] 0\n", + " Conv2d-422 [-1, 768, 8, 12] 2,654,208\n", + " BatchNorm2d-423 [-1, 768, 8, 12] 1,536\n", + " BasicConv2d-424 [-1, 768, 8, 12] 0\n", + " BatchNorm2d-425 [-1, 1792, 8, 12] 3,584\n", + "Multi_2D_CNN_block-426 [-1, 1792, 8, 12] 0\n", + " ZeroPad2d-427 [-1, 1792, 8, 12] 0\n", + " Conv2d-428 [-1, 298, 8, 12] 534,016\n", + " BatchNorm2d-429 [-1, 298, 8, 12] 596\n", + " BasicConv2d-430 [-1, 298, 8, 12] 0\n", + " ZeroPad2d-431 [-1, 1792, 8, 12] 0\n", + " Conv2d-432 [-1, 448, 8, 12] 802,816\n", + " BatchNorm2d-433 [-1, 448, 8, 12] 896\n", + " BasicConv2d-434 [-1, 448, 8, 12] 0\n", + " ZeroPad2d-435 [-1, 448, 10, 14] 0\n", + " Conv2d-436 [-1, 896, 8, 12] 3,612,672\n", + " BatchNorm2d-437 [-1, 896, 8, 12] 1,792\n", + " BasicConv2d-438 [-1, 896, 8, 12] 0\n", + " ZeroPad2d-439 [-1, 1792, 8, 12] 0\n", + " Conv2d-440 [-1, 298, 8, 12] 534,016\n", + " BatchNorm2d-441 [-1, 298, 8, 12] 596\n", + " BasicConv2d-442 [-1, 298, 8, 12] 0\n", + " ZeroPad2d-443 [-1, 298, 10, 14] 0\n", + " Conv2d-444 [-1, 448, 8, 12] 1,201,536\n", + " BatchNorm2d-445 [-1, 448, 8, 12] 896\n", + " BasicConv2d-446 [-1, 448, 8, 12] 0\n", + " ZeroPad2d-447 [-1, 448, 10, 14] 0\n", + " Conv2d-448 [-1, 896, 8, 12] 3,612,672\n", + " BatchNorm2d-449 [-1, 896, 8, 12] 1,792\n", + " BasicConv2d-450 [-1, 896, 8, 12] 0\n", + " BatchNorm2d-451 [-1, 2090, 8, 12] 
4,180\n", + "Multi_2D_CNN_block-452 [-1, 2090, 8, 12] 0\n", + " ZeroPad2d-453 [-1, 2090, 8, 12] 0\n", + " Conv2d-454 [-1, 341, 8, 12] 712,690\n", + " BatchNorm2d-455 [-1, 341, 8, 12] 682\n", + " BasicConv2d-456 [-1, 341, 8, 12] 0\n", + " ZeroPad2d-457 [-1, 2090, 8, 12] 0\n", + " Conv2d-458 [-1, 512, 8, 12] 1,070,080\n", + " BatchNorm2d-459 [-1, 512, 8, 12] 1,024\n", + " BasicConv2d-460 [-1, 512, 8, 12] 0\n", + " ZeroPad2d-461 [-1, 512, 10, 14] 0\n", + " Conv2d-462 [-1, 1024, 8, 12] 4,718,592\n", + " BatchNorm2d-463 [-1, 1024, 8, 12] 2,048\n", + " BasicConv2d-464 [-1, 1024, 8, 12] 0\n", + " ZeroPad2d-465 [-1, 2090, 8, 12] 0\n", + " Conv2d-466 [-1, 341, 8, 12] 712,690\n", + " BatchNorm2d-467 [-1, 341, 8, 12] 682\n", + " BasicConv2d-468 [-1, 341, 8, 12] 0\n", + " ZeroPad2d-469 [-1, 341, 10, 14] 0\n", + " Conv2d-470 [-1, 512, 8, 12] 1,571,328\n", + " BatchNorm2d-471 [-1, 512, 8, 12] 1,024\n", + " BasicConv2d-472 [-1, 512, 8, 12] 0\n", + " ZeroPad2d-473 [-1, 512, 10, 14] 0\n", + " Conv2d-474 [-1, 1024, 8, 12] 4,718,592\n", + " BatchNorm2d-475 [-1, 1024, 8, 12] 2,048\n", + " BasicConv2d-476 [-1, 1024, 8, 12] 0\n", + " BatchNorm2d-477 [-1, 2389, 8, 12] 4,778\n", + "Multi_2D_CNN_block-478 [-1, 2389, 8, 12] 0\n", + "AdaptiveAvgPool2d-479 [-1, 2389, 1, 1] 0\n", + " Flatten-480 [-1, 2389] 0\n", + " Dropout-481 [-1, 2389] 0\n", + " Linear-482 [-1, 1] 2,390\n", + "================================================================\n", + "Total params: 49,719,798\n", + "Trainable params: 49,719,798\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.11\n", + "Forward/backward pass size (MB): 884.62\n", + "Params size (MB): 189.67\n", + "Estimated Total Size (MB): 1074.40\n", + "----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(model, input_size=(1, 2500, 12))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " ZeroPad2d-1 [-1, 1, 2506, 14] 0\n", + " Conv2d-2 [-1, 64, 1250, 12] 1,344\n", + " BatchNorm2d-3 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-4 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-5 [-1, 64, 1250, 12] 0\n", + " Conv2d-6 [-1, 21, 1250, 12] 1,344\n", + " BatchNorm2d-7 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-8 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-9 [-1, 64, 1250, 12] 0\n", + " Conv2d-10 [-1, 32, 1250, 12] 2,048\n", + " BatchNorm2d-11 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-12 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-13 [-1, 32, 1252, 14] 0\n", + " Conv2d-14 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-15 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-16 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-17 [-1, 64, 1250, 12] 0\n", + " Conv2d-18 [-1, 21, 1250, 12] 1,344\n", + " BatchNorm2d-19 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-20 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-21 [-1, 21, 1252, 14] 0\n", + " Conv2d-22 [-1, 32, 1250, 12] 6,048\n", + " BatchNorm2d-23 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-24 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-25 [-1, 32, 1252, 14] 0\n", + " Conv2d-26 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-27 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-28 [-1, 64, 1250, 12] 0\n", + " BatchNorm2d-29 [-1, 149, 1250, 12] 298\n", + 
"Multi_2D_CNN_block-30 [-1, 149, 1250, 12] 0\n", + " ZeroPad2d-31 [-1, 149, 1250, 12] 0\n", + " Conv2d-32 [-1, 21, 1250, 12] 3,129\n", + " BatchNorm2d-33 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-34 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-35 [-1, 149, 1250, 12] 0\n", + " Conv2d-36 [-1, 32, 1250, 12] 4,768\n", + " BatchNorm2d-37 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-38 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-39 [-1, 32, 1252, 14] 0\n", + " Conv2d-40 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-41 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-42 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-43 [-1, 149, 1250, 12] 0\n", + " Conv2d-44 [-1, 21, 1250, 12] 3,129\n", + " BatchNorm2d-45 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-46 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-47 [-1, 21, 1252, 14] 0\n", + " Conv2d-48 [-1, 32, 1250, 12] 6,048\n", + " BatchNorm2d-49 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-50 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-51 [-1, 32, 1252, 14] 0\n", + " Conv2d-52 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-53 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-54 [-1, 64, 1250, 12] 0\n", + " BatchNorm2d-55 [-1, 149, 1250, 12] 298\n", + "Multi_2D_CNN_block-56 [-1, 149, 1250, 12] 0\n", + " MaxPool2d-57 [-1, 149, 416, 12] 0\n", + " ZeroPad2d-58 [-1, 149, 416, 12] 0\n", + " Conv2d-59 [-1, 32, 416, 12] 4,768\n", + " BatchNorm2d-60 [-1, 32, 416, 12] 64\n", + " BasicConv2d-61 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-62 [-1, 149, 416, 12] 0\n", + " Conv2d-63 [-1, 48, 416, 12] 7,152\n", + " BatchNorm2d-64 [-1, 48, 416, 12] 96\n", + " BasicConv2d-65 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-66 [-1, 48, 418, 14] 0\n", + " Conv2d-67 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-68 [-1, 96, 416, 12] 192\n", + " BasicConv2d-69 [-1, 96, 416, 12] 0\n", + " ZeroPad2d-70 [-1, 149, 416, 12] 0\n", + " Conv2d-71 [-1, 32, 416, 12] 4,768\n", + " BatchNorm2d-72 [-1, 32, 416, 12] 64\n", + " BasicConv2d-73 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-74 [-1, 32, 418, 14] 0\n", + " Conv2d-75 [-1, 48, 416, 12] 13,824\n", + " BatchNorm2d-76 [-1, 48, 416, 12] 96\n", + " BasicConv2d-77 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-78 [-1, 48, 418, 14] 0\n", + " Conv2d-79 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-80 [-1, 96, 416, 12] 192\n", + " BasicConv2d-81 [-1, 96, 416, 12] 0\n", + " BatchNorm2d-82 [-1, 224, 416, 12] 448\n", + "Multi_2D_CNN_block-83 [-1, 224, 416, 12] 0\n", + " ZeroPad2d-84 [-1, 224, 416, 12] 0\n", + " Conv2d-85 [-1, 32, 416, 12] 7,168\n", + " BatchNorm2d-86 [-1, 32, 416, 12] 64\n", + " BasicConv2d-87 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-88 [-1, 224, 416, 12] 0\n", + " Conv2d-89 [-1, 48, 416, 12] 10,752\n", + " BatchNorm2d-90 [-1, 48, 416, 12] 96\n", + " BasicConv2d-91 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-92 [-1, 48, 418, 14] 0\n", + " Conv2d-93 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-94 [-1, 96, 416, 12] 192\n", + " BasicConv2d-95 [-1, 96, 416, 12] 0\n", + " ZeroPad2d-96 [-1, 224, 416, 12] 0\n", + " Conv2d-97 [-1, 32, 416, 12] 7,168\n", + " BatchNorm2d-98 [-1, 32, 416, 12] 64\n", + " BasicConv2d-99 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-100 [-1, 32, 418, 14] 0\n", + " Conv2d-101 [-1, 48, 416, 12] 13,824\n", + " BatchNorm2d-102 [-1, 48, 416, 12] 96\n", + " BasicConv2d-103 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-104 [-1, 48, 418, 14] 0\n", + " Conv2d-105 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-106 [-1, 96, 416, 12] 192\n", + " BasicConv2d-107 [-1, 96, 416, 12] 0\n", + " BatchNorm2d-108 [-1, 224, 416, 12] 448\n", + "Multi_2D_CNN_block-109 [-1, 224, 416, 12] 0\n", + " MaxPool2d-110 [-1, 224, 138, 12] 0\n", + " 
ZeroPad2d-111 [-1, 224, 138, 12] 0\n", + " Conv2d-112 [-1, 42, 138, 12] 9,408\n", + " BatchNorm2d-113 [-1, 42, 138, 12] 84\n", + " BasicConv2d-114 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-115 [-1, 224, 138, 12] 0\n", + " Conv2d-116 [-1, 64, 138, 12] 14,336\n", + " BatchNorm2d-117 [-1, 64, 138, 12] 128\n", + " BasicConv2d-118 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-119 [-1, 64, 140, 14] 0\n", + " Conv2d-120 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-121 [-1, 128, 138, 12] 256\n", + " BasicConv2d-122 [-1, 128, 138, 12] 0\n", + " ZeroPad2d-123 [-1, 224, 138, 12] 0\n", + " Conv2d-124 [-1, 42, 138, 12] 9,408\n", + " BatchNorm2d-125 [-1, 42, 138, 12] 84\n", + " BasicConv2d-126 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-127 [-1, 42, 140, 14] 0\n", + " Conv2d-128 [-1, 64, 138, 12] 24,192\n", + " BatchNorm2d-129 [-1, 64, 138, 12] 128\n", + " BasicConv2d-130 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-131 [-1, 64, 140, 14] 0\n", + " Conv2d-132 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-133 [-1, 128, 138, 12] 256\n", + " BasicConv2d-134 [-1, 128, 138, 12] 0\n", + " BatchNorm2d-135 [-1, 298, 138, 12] 596\n", + "Multi_2D_CNN_block-136 [-1, 298, 138, 12] 0\n", + " ZeroPad2d-137 [-1, 298, 138, 12] 0\n", + " Conv2d-138 [-1, 42, 138, 12] 12,516\n", + " BatchNorm2d-139 [-1, 42, 138, 12] 84\n", + " BasicConv2d-140 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-141 [-1, 298, 138, 12] 0\n", + " Conv2d-142 [-1, 64, 138, 12] 19,072\n", + " BatchNorm2d-143 [-1, 64, 138, 12] 128\n", + " BasicConv2d-144 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-145 [-1, 64, 140, 14] 0\n", + " Conv2d-146 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-147 [-1, 128, 138, 12] 256\n", + " BasicConv2d-148 [-1, 128, 138, 12] 0\n", + " ZeroPad2d-149 [-1, 298, 138, 12] 0\n", + " Conv2d-150 [-1, 42, 138, 12] 12,516\n", + " BatchNorm2d-151 [-1, 42, 138, 12] 84\n", + " BasicConv2d-152 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-153 [-1, 42, 140, 14] 0\n", + " Conv2d-154 [-1, 64, 138, 12] 24,192\n", + " BatchNorm2d-155 [-1, 64, 138, 12] 128\n", + " BasicConv2d-156 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-157 [-1, 64, 140, 14] 0\n", + " Conv2d-158 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-159 [-1, 128, 138, 12] 256\n", + " BasicConv2d-160 [-1, 128, 138, 12] 0\n", + " BatchNorm2d-161 [-1, 298, 138, 12] 596\n", + "Multi_2D_CNN_block-162 [-1, 298, 138, 12] 0\n", + " MaxPool2d-163 [-1, 298, 69, 12] 0\n", + " ZeroPad2d-164 [-1, 298, 69, 12] 0\n", + " Conv2d-165 [-1, 64, 69, 12] 19,072\n", + " BatchNorm2d-166 [-1, 64, 69, 12] 128\n", + " BasicConv2d-167 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-168 [-1, 298, 69, 12] 0\n", + " Conv2d-169 [-1, 96, 69, 12] 28,608\n", + " BatchNorm2d-170 [-1, 96, 69, 12] 192\n", + " BasicConv2d-171 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-172 [-1, 96, 71, 14] 0\n", + " Conv2d-173 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-174 [-1, 192, 69, 12] 384\n", + " BasicConv2d-175 [-1, 192, 69, 12] 0\n", + " ZeroPad2d-176 [-1, 298, 69, 12] 0\n", + " Conv2d-177 [-1, 64, 69, 12] 19,072\n", + " BatchNorm2d-178 [-1, 64, 69, 12] 128\n", + " BasicConv2d-179 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-180 [-1, 64, 71, 14] 0\n", + " Conv2d-181 [-1, 96, 69, 12] 55,296\n", + " BatchNorm2d-182 [-1, 96, 69, 12] 192\n", + " BasicConv2d-183 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-184 [-1, 96, 71, 14] 0\n", + " Conv2d-185 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-186 [-1, 192, 69, 12] 384\n", + " BasicConv2d-187 [-1, 192, 69, 12] 0\n", + " BatchNorm2d-188 [-1, 448, 69, 12] 896\n", + "Multi_2D_CNN_block-189 [-1, 448, 69, 12] 0\n", + " ZeroPad2d-190 [-1, 448, 69, 12] 0\n", + " Conv2d-191 
[-1, 64, 69, 12] 28,672\n", + " BatchNorm2d-192 [-1, 64, 69, 12] 128\n", + " BasicConv2d-193 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-194 [-1, 448, 69, 12] 0\n", + " Conv2d-195 [-1, 96, 69, 12] 43,008\n", + " BatchNorm2d-196 [-1, 96, 69, 12] 192\n", + " BasicConv2d-197 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-198 [-1, 96, 71, 14] 0\n", + " Conv2d-199 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-200 [-1, 192, 69, 12] 384\n", + " BasicConv2d-201 [-1, 192, 69, 12] 0\n", + " ZeroPad2d-202 [-1, 448, 69, 12] 0\n", + " Conv2d-203 [-1, 64, 69, 12] 28,672\n", + " BatchNorm2d-204 [-1, 64, 69, 12] 128\n", + " BasicConv2d-205 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-206 [-1, 64, 71, 14] 0\n", + " Conv2d-207 [-1, 96, 69, 12] 55,296\n", + " BatchNorm2d-208 [-1, 96, 69, 12] 192\n", + " BasicConv2d-209 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-210 [-1, 96, 71, 14] 0\n", + " Conv2d-211 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-212 [-1, 192, 69, 12] 384\n", + " BasicConv2d-213 [-1, 192, 69, 12] 0\n", + " BatchNorm2d-214 [-1, 448, 69, 12] 896\n", + "Multi_2D_CNN_block-215 [-1, 448, 69, 12] 0\n", + " ZeroPad2d-216 [-1, 448, 69, 12] 0\n", + " Conv2d-217 [-1, 85, 69, 12] 38,080\n", + " BatchNorm2d-218 [-1, 85, 69, 12] 170\n", + " BasicConv2d-219 [-1, 85, 69, 12] 0\n", + " ZeroPad2d-220 [-1, 448, 69, 12] 0\n", + " Conv2d-221 [-1, 128, 69, 12] 57,344\n", + " BatchNorm2d-222 [-1, 128, 69, 12] 256\n", + " BasicConv2d-223 [-1, 128, 69, 12] 0\n", + " ZeroPad2d-224 [-1, 128, 71, 14] 0\n", + " Conv2d-225 [-1, 256, 69, 12] 294,912\n", + " BatchNorm2d-226 [-1, 256, 69, 12] 512\n", + " BasicConv2d-227 [-1, 256, 69, 12] 0\n", + " ZeroPad2d-228 [-1, 448, 69, 12] 0\n", + " Conv2d-229 [-1, 85, 69, 12] 38,080\n", + " BatchNorm2d-230 [-1, 85, 69, 12] 170\n", + " BasicConv2d-231 [-1, 85, 69, 12] 0\n", + " ZeroPad2d-232 [-1, 85, 71, 14] 0\n", + " Conv2d-233 [-1, 128, 69, 12] 97,920\n", + " BatchNorm2d-234 [-1, 128, 69, 12] 256\n", + " BasicConv2d-235 [-1, 128, 69, 12] 0\n", + " ZeroPad2d-236 [-1, 128, 71, 14] 0\n", + " Conv2d-237 [-1, 256, 69, 12] 294,912\n", + " BatchNorm2d-238 [-1, 256, 69, 12] 512\n", + " BasicConv2d-239 [-1, 256, 69, 12] 0\n", + " BatchNorm2d-240 [-1, 597, 69, 12] 1,194\n", + "Multi_2D_CNN_block-241 [-1, 597, 69, 12] 0\n", + " MaxPool2d-242 [-1, 597, 34, 12] 0\n", + " ZeroPad2d-243 [-1, 597, 34, 12] 0\n", + " Conv2d-244 [-1, 106, 34, 12] 63,282\n", + " BatchNorm2d-245 [-1, 106, 34, 12] 212\n", + " BasicConv2d-246 [-1, 106, 34, 12] 0\n", + " ZeroPad2d-247 [-1, 597, 34, 12] 0\n", + " Conv2d-248 [-1, 160, 34, 12] 95,520\n", + " BatchNorm2d-249 [-1, 160, 34, 12] 320\n", + " BasicConv2d-250 [-1, 160, 34, 12] 0\n", + " ZeroPad2d-251 [-1, 160, 36, 14] 0\n", + " Conv2d-252 [-1, 320, 34, 12] 460,800\n", + " BatchNorm2d-253 [-1, 320, 34, 12] 640\n", + " BasicConv2d-254 [-1, 320, 34, 12] 0\n", + " ZeroPad2d-255 [-1, 597, 34, 12] 0\n", + " Conv2d-256 [-1, 106, 34, 12] 63,282\n", + " BatchNorm2d-257 [-1, 106, 34, 12] 212\n", + " BasicConv2d-258 [-1, 106, 34, 12] 0\n", + " ZeroPad2d-259 [-1, 106, 36, 14] 0\n", + " Conv2d-260 [-1, 160, 34, 12] 152,640\n", + " BatchNorm2d-261 [-1, 160, 34, 12] 320\n", + " BasicConv2d-262 [-1, 160, 34, 12] 0\n", + " ZeroPad2d-263 [-1, 160, 36, 14] 0\n", + " Conv2d-264 [-1, 320, 34, 12] 460,800\n", + " BatchNorm2d-265 [-1, 320, 34, 12] 640\n", + " BasicConv2d-266 [-1, 320, 34, 12] 0\n", + " BatchNorm2d-267 [-1, 746, 34, 12] 1,492\n", + "Multi_2D_CNN_block-268 [-1, 746, 34, 12] 0\n", + " ZeroPad2d-269 [-1, 746, 34, 12] 0\n", + " Conv2d-270 [-1, 128, 34, 12] 95,488\n", + " BatchNorm2d-271 [-1, 128, 34, 12] 
256\n", + " BasicConv2d-272 [-1, 128, 34, 12] 0\n", + " ZeroPad2d-273 [-1, 746, 34, 12] 0\n", + " Conv2d-274 [-1, 192, 34, 12] 143,232\n", + " BatchNorm2d-275 [-1, 192, 34, 12] 384\n", + " BasicConv2d-276 [-1, 192, 34, 12] 0\n", + " ZeroPad2d-277 [-1, 192, 36, 14] 0\n", + " Conv2d-278 [-1, 384, 34, 12] 663,552\n", + " BatchNorm2d-279 [-1, 384, 34, 12] 768\n", + " BasicConv2d-280 [-1, 384, 34, 12] 0\n", + " ZeroPad2d-281 [-1, 746, 34, 12] 0\n", + " Conv2d-282 [-1, 128, 34, 12] 95,488\n", + " BatchNorm2d-283 [-1, 128, 34, 12] 256\n", + " BasicConv2d-284 [-1, 128, 34, 12] 0\n", + " ZeroPad2d-285 [-1, 128, 36, 14] 0\n", + " Conv2d-286 [-1, 192, 34, 12] 221,184\n", + " BatchNorm2d-287 [-1, 192, 34, 12] 384\n", + " BasicConv2d-288 [-1, 192, 34, 12] 0\n", + " ZeroPad2d-289 [-1, 192, 36, 14] 0\n", + " Conv2d-290 [-1, 384, 34, 12] 663,552\n", + " BatchNorm2d-291 [-1, 384, 34, 12] 768\n", + " BasicConv2d-292 [-1, 384, 34, 12] 0\n", + " BatchNorm2d-293 [-1, 896, 34, 12] 1,792\n", + "Multi_2D_CNN_block-294 [-1, 896, 34, 12] 0\n", + " ZeroPad2d-295 [-1, 896, 34, 12] 0\n", + " Conv2d-296 [-1, 149, 34, 12] 133,504\n", + " BatchNorm2d-297 [-1, 149, 34, 12] 298\n", + " BasicConv2d-298 [-1, 149, 34, 12] 0\n", + " ZeroPad2d-299 [-1, 896, 34, 12] 0\n", + " Conv2d-300 [-1, 224, 34, 12] 200,704\n", + " BatchNorm2d-301 [-1, 224, 34, 12] 448\n", + " BasicConv2d-302 [-1, 224, 34, 12] 0\n", + " ZeroPad2d-303 [-1, 224, 36, 14] 0\n", + " Conv2d-304 [-1, 448, 34, 12] 903,168\n", + " BatchNorm2d-305 [-1, 448, 34, 12] 896\n", + " BasicConv2d-306 [-1, 448, 34, 12] 0\n", + " ZeroPad2d-307 [-1, 896, 34, 12] 0\n", + " Conv2d-308 [-1, 149, 34, 12] 133,504\n", + " BatchNorm2d-309 [-1, 149, 34, 12] 298\n", + " BasicConv2d-310 [-1, 149, 34, 12] 0\n", + " ZeroPad2d-311 [-1, 149, 36, 14] 0\n", + " Conv2d-312 [-1, 224, 34, 12] 300,384\n", + " BatchNorm2d-313 [-1, 224, 34, 12] 448\n", + " BasicConv2d-314 [-1, 224, 34, 12] 0\n", + " ZeroPad2d-315 [-1, 224, 36, 14] 0\n", + " Conv2d-316 [-1, 448, 34, 12] 903,168\n", + " BatchNorm2d-317 [-1, 448, 34, 12] 896\n", + " BasicConv2d-318 [-1, 448, 34, 12] 0\n", + " BatchNorm2d-319 [-1, 1045, 34, 12] 2,090\n", + "Multi_2D_CNN_block-320 [-1, 1045, 34, 12] 0\n", + " MaxPool2d-321 [-1, 1045, 17, 12] 0\n", + " ZeroPad2d-322 [-1, 1045, 17, 12] 0\n", + " Conv2d-323 [-1, 170, 17, 12] 177,650\n", + " BatchNorm2d-324 [-1, 170, 17, 12] 340\n", + " BasicConv2d-325 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-326 [-1, 1045, 17, 12] 0\n", + " Conv2d-327 [-1, 256, 17, 12] 267,520\n", + " BatchNorm2d-328 [-1, 256, 17, 12] 512\n", + " BasicConv2d-329 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-330 [-1, 256, 19, 14] 0\n", + " Conv2d-331 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-332 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-333 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-334 [-1, 1045, 17, 12] 0\n", + " Conv2d-335 [-1, 170, 17, 12] 177,650\n", + " BatchNorm2d-336 [-1, 170, 17, 12] 340\n", + " BasicConv2d-337 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-338 [-1, 170, 19, 14] 0\n", + " Conv2d-339 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-340 [-1, 256, 17, 12] 512\n", + " BasicConv2d-341 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-342 [-1, 256, 19, 14] 0\n", + " Conv2d-343 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-344 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-345 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-346 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-347 [-1, 1194, 17, 12] 0\n", + " ZeroPad2d-348 [-1, 1194, 17, 12] 0\n", + " Conv2d-349 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-350 [-1, 170, 17, 12] 340\n", + " 
BasicConv2d-351 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-352 [-1, 1194, 17, 12] 0\n", + " Conv2d-353 [-1, 256, 17, 12] 305,664\n", + " BatchNorm2d-354 [-1, 256, 17, 12] 512\n", + " BasicConv2d-355 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-356 [-1, 256, 19, 14] 0\n", + " Conv2d-357 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-358 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-359 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-360 [-1, 1194, 17, 12] 0\n", + " Conv2d-361 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-362 [-1, 170, 17, 12] 340\n", + " BasicConv2d-363 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-364 [-1, 170, 19, 14] 0\n", + " Conv2d-365 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-366 [-1, 256, 17, 12] 512\n", + " BasicConv2d-367 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-368 [-1, 256, 19, 14] 0\n", + " Conv2d-369 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-370 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-371 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-372 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-373 [-1, 1194, 17, 12] 0\n", + " ZeroPad2d-374 [-1, 1194, 17, 12] 0\n", + " Conv2d-375 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-376 [-1, 170, 17, 12] 340\n", + " BasicConv2d-377 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-378 [-1, 1194, 17, 12] 0\n", + " Conv2d-379 [-1, 256, 17, 12] 305,664\n", + " BatchNorm2d-380 [-1, 256, 17, 12] 512\n", + " BasicConv2d-381 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-382 [-1, 256, 19, 14] 0\n", + " Conv2d-383 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-384 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-385 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-386 [-1, 1194, 17, 12] 0\n", + " Conv2d-387 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-388 [-1, 170, 17, 12] 340\n", + " BasicConv2d-389 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-390 [-1, 170, 19, 14] 0\n", + " Conv2d-391 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-392 [-1, 256, 17, 12] 512\n", + " BasicConv2d-393 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-394 [-1, 256, 19, 14] 0\n", + " Conv2d-395 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-396 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-397 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-398 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-399 [-1, 1194, 17, 12] 0\n", + " MaxPool2d-400 [-1, 1194, 8, 12] 0\n", + " ZeroPad2d-401 [-1, 1194, 8, 12] 0\n", + " Conv2d-402 [-1, 256, 8, 12] 305,664\n", + " BatchNorm2d-403 [-1, 256, 8, 12] 512\n", + " BasicConv2d-404 [-1, 256, 8, 12] 0\n", + " ZeroPad2d-405 [-1, 1194, 8, 12] 0\n", + " Conv2d-406 [-1, 384, 8, 12] 458,496\n", + " BatchNorm2d-407 [-1, 384, 8, 12] 768\n", + " BasicConv2d-408 [-1, 384, 8, 12] 0\n", + " ZeroPad2d-409 [-1, 384, 10, 14] 0\n", + " Conv2d-410 [-1, 768, 8, 12] 2,654,208\n", + " BatchNorm2d-411 [-1, 768, 8, 12] 1,536\n", + " BasicConv2d-412 [-1, 768, 8, 12] 0\n", + " ZeroPad2d-413 [-1, 1194, 8, 12] 0\n", + " Conv2d-414 [-1, 256, 8, 12] 305,664\n", + " BatchNorm2d-415 [-1, 256, 8, 12] 512\n", + " BasicConv2d-416 [-1, 256, 8, 12] 0\n", + " ZeroPad2d-417 [-1, 256, 10, 14] 0\n", + " Conv2d-418 [-1, 384, 8, 12] 884,736\n", + " BatchNorm2d-419 [-1, 384, 8, 12] 768\n", + " BasicConv2d-420 [-1, 384, 8, 12] 0\n", + " ZeroPad2d-421 [-1, 384, 10, 14] 0\n", + " Conv2d-422 [-1, 768, 8, 12] 2,654,208\n", + " BatchNorm2d-423 [-1, 768, 8, 12] 1,536\n", + " BasicConv2d-424 [-1, 768, 8, 12] 0\n", + " BatchNorm2d-425 [-1, 1792, 8, 12] 3,584\n", + "Multi_2D_CNN_block-426 [-1, 1792, 8, 12] 0\n", + " ZeroPad2d-427 [-1, 1792, 8, 12] 0\n", + " Conv2d-428 [-1, 298, 8, 12] 534,016\n", + " BatchNorm2d-429 [-1, 298, 8, 12] 596\n", + " BasicConv2d-430 
[-1, 298, 8, 12] 0\n", + " ZeroPad2d-431 [-1, 1792, 8, 12] 0\n", + " Conv2d-432 [-1, 448, 8, 12] 802,816\n", + " BatchNorm2d-433 [-1, 448, 8, 12] 896\n", + " BasicConv2d-434 [-1, 448, 8, 12] 0\n", + " ZeroPad2d-435 [-1, 448, 10, 14] 0\n", + " Conv2d-436 [-1, 896, 8, 12] 3,612,672\n", + " BatchNorm2d-437 [-1, 896, 8, 12] 1,792\n", + " BasicConv2d-438 [-1, 896, 8, 12] 0\n", + " ZeroPad2d-439 [-1, 1792, 8, 12] 0\n", + " Conv2d-440 [-1, 298, 8, 12] 534,016\n", + " BatchNorm2d-441 [-1, 298, 8, 12] 596\n", + " BasicConv2d-442 [-1, 298, 8, 12] 0\n", + " ZeroPad2d-443 [-1, 298, 10, 14] 0\n", + " Conv2d-444 [-1, 448, 8, 12] 1,201,536\n", + " BatchNorm2d-445 [-1, 448, 8, 12] 896\n", + " BasicConv2d-446 [-1, 448, 8, 12] 0\n", + " ZeroPad2d-447 [-1, 448, 10, 14] 0\n", + " Conv2d-448 [-1, 896, 8, 12] 3,612,672\n", + " BatchNorm2d-449 [-1, 896, 8, 12] 1,792\n", + " BasicConv2d-450 [-1, 896, 8, 12] 0\n", + " BatchNorm2d-451 [-1, 2090, 8, 12] 4,180\n", + "Multi_2D_CNN_block-452 [-1, 2090, 8, 12] 0\n", + " ZeroPad2d-453 [-1, 2090, 8, 12] 0\n", + " Conv2d-454 [-1, 341, 8, 12] 712,690\n", + " BatchNorm2d-455 [-1, 341, 8, 12] 682\n", + " BasicConv2d-456 [-1, 341, 8, 12] 0\n", + " ZeroPad2d-457 [-1, 2090, 8, 12] 0\n", + " Conv2d-458 [-1, 512, 8, 12] 1,070,080\n", + " BatchNorm2d-459 [-1, 512, 8, 12] 1,024\n", + " BasicConv2d-460 [-1, 512, 8, 12] 0\n", + " ZeroPad2d-461 [-1, 512, 10, 14] 0\n", + " Conv2d-462 [-1, 1024, 8, 12] 4,718,592\n", + " BatchNorm2d-463 [-1, 1024, 8, 12] 2,048\n", + " BasicConv2d-464 [-1, 1024, 8, 12] 0\n", + " ZeroPad2d-465 [-1, 2090, 8, 12] 0\n", + " Conv2d-466 [-1, 341, 8, 12] 712,690\n", + " BatchNorm2d-467 [-1, 341, 8, 12] 682\n", + " BasicConv2d-468 [-1, 341, 8, 12] 0\n", + " ZeroPad2d-469 [-1, 341, 10, 14] 0\n", + " Conv2d-470 [-1, 512, 8, 12] 1,571,328\n", + " BatchNorm2d-471 [-1, 512, 8, 12] 1,024\n", + " BasicConv2d-472 [-1, 512, 8, 12] 0\n", + " ZeroPad2d-473 [-1, 512, 10, 14] 0\n", + " Conv2d-474 [-1, 1024, 8, 12] 4,718,592\n", + " BatchNorm2d-475 [-1, 1024, 8, 12] 2,048\n", + " BasicConv2d-476 [-1, 1024, 8, 12] 0\n", + " BatchNorm2d-477 [-1, 2389, 8, 12] 4,778\n", + "Multi_2D_CNN_block-478 [-1, 2389, 8, 12] 0\n", + "AdaptiveAvgPool2d-479 [-1, 2389, 1, 1] 0\n", + " Flatten-480 [-1, 2389] 0\n", + " Dropout-481 [-1, 2389] 0\n", + " Linear-482 [-1, 1] 2,390\n", + "================================================================\n", + "Total params: 49,719,798\n", + "Trainable params: 49,719,798\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.11\n", + "Forward/backward pass size (MB): 884.62\n", + "Params size (MB): 189.67\n", + "Estimated Total Size (MB): 1074.40\n", + "----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(model, input_size=(1, 2500, 12))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "class Inception_Tabular_7_1(nn.Module):\n", + "\n", + " def __init__(self, initial_kernel_num=64):\n", + " super(Inception_Tabular_7_1, self).__init__()\n", + "\n", + " multi_2d_cnn = Multi_2D_CNN_block\n", + " conv_block = BasicConv2d\n", + "\n", + " self.conv_1 = conv_block(1, 64, kernel_size=(7, 1), stride=(2, 1))\n", + "\n", + " self.multi_2d_cnn_1a = nn.Sequential(\n", + " multi_2d_cnn(in_channels=64, num_kernel=initial_kernel_num),\n", + " multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num),\n", + " nn.MaxPool2d(kernel_size=(3, 1))\n", + " )\n", + 
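+ " # in_channels at each stage equals the previous Multi_2D_CNN_block's out_channels, int(num_kernel / 3) + 2 * num_kernel (149 for num_kernel=64)\n",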
"\n", + " self.multi_2d_cnn_1b = nn.Sequential(\n", + " multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num * 1.5),\n", + " multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 1.5),\n", + " nn.MaxPool2d(kernel_size=(3, 1))\n", + " )\n", + " self.multi_2d_cnn_1c = nn.Sequential(\n", + " multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 2),\n", + " multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 2),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + "\n", + " self.multi_2d_cnn_2a = nn.Sequential(\n", + " multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 3),\n", + " multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 3),\n", + " multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 4),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2b = nn.Sequential(\n", + " multi_2d_cnn(in_channels=597, num_kernel=initial_kernel_num * 5),\n", + " multi_2d_cnn(in_channels=746, num_kernel=initial_kernel_num * 6),\n", + " multi_2d_cnn(in_channels=896, num_kernel=initial_kernel_num * 7),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2c = nn.Sequential(\n", + " multi_2d_cnn(in_channels=1045, num_kernel=initial_kernel_num * 8),\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2d = nn.Sequential(\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 12),\n", + " multi_2d_cnn(in_channels=1792, num_kernel=initial_kernel_num * 14),\n", + " multi_2d_cnn(in_channels=2090, num_kernel=initial_kernel_num * 16),\n", + " )\n", + " self.output = nn.Sequential(\n", + " nn.AdaptiveAvgPool2d((1, 1)),\n", + " nn.Flatten(),\n", + " nn.Dropout(0.5),\n", + " nn.Linear(2389, 1),\n", + " # nn.Sigmoid()\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.conv_1(x)\n", + " # N x 1250 x 12 x 64 tensor\n", + " x = self.multi_2d_cnn_1a(x)\n", + " # N x 416 x 12 x 149 tensor\n", + " x = self.multi_2d_cnn_1b(x)\n", + " # N x 138 x 12 x 224 tensor\n", + " x = self.multi_2d_cnn_1c(x)\n", + " # N x 69 x 12 x 298\n", + " x = self.multi_2d_cnn_2a(x)\n", + "\n", + " x = self.multi_2d_cnn_2b(x)\n", + "\n", + " x = self.multi_2d_cnn_2c(x)\n", + "\n", + " x = self.multi_2d_cnn_2d(x)\n", + "\n", + " x = self.output(x)\n", + "\n", + " return x\n", + "\n", + "class Inception_Tabular_15_3(nn.Module):\n", + "\n", + " def __init__(self, initial_kernel_num=64):\n", + " super(Inception_Tabular_15_3, self).__init__()\n", + "\n", + " multi_2d_cnn = Multi_2D_CNN_block\n", + " conv_block = BasicConv2d\n", + "\n", + " self.conv_1 = conv_block(1, 64, kernel_size=(15, 3), stride=(2, 1))\n", + "\n", + " self.multi_2d_cnn_1a = nn.Sequential(\n", + " multi_2d_cnn(in_channels=64, num_kernel=initial_kernel_num),\n", + " multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num),\n", + " nn.MaxPool2d(kernel_size=(3, 1))\n", + " )\n", + "\n", + " self.multi_2d_cnn_1b = nn.Sequential(\n", + " multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num * 1.5),\n", + " multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 1.5),\n", + " nn.MaxPool2d(kernel_size=(3, 1))\n", + " )\n", + " self.multi_2d_cnn_1c = nn.Sequential(\n", + " multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 2),\n", + " multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 2),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", 
+ "\n", + " self.multi_2d_cnn_2a = nn.Sequential(\n", + " multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 3),\n", + " multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 3),\n", + " multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 4),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2b = nn.Sequential(\n", + " multi_2d_cnn(in_channels=597, num_kernel=initial_kernel_num * 5),\n", + " multi_2d_cnn(in_channels=746, num_kernel=initial_kernel_num * 6),\n", + " multi_2d_cnn(in_channels=896, num_kernel=initial_kernel_num * 7),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2c = nn.Sequential(\n", + " multi_2d_cnn(in_channels=1045, num_kernel=initial_kernel_num * 8),\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2d = nn.Sequential(\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 12),\n", + " multi_2d_cnn(in_channels=1792, num_kernel=initial_kernel_num * 14),\n", + " multi_2d_cnn(in_channels=2090, num_kernel=initial_kernel_num * 16),\n", + " )\n", + " self.output = nn.Sequential(\n", + " nn.AdaptiveAvgPool2d((1, 1)),\n", + " nn.Flatten(),\n", + " nn.Dropout(0.5),\n", + " nn.Linear(2389, 1),\n", + " # nn.Sigmoid()\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.conv_1(x)\n", + " # N x 1250 x 12 x 64 tensor\n", + " x = self.multi_2d_cnn_1a(x)\n", + " # N x 416 x 12 x 149 tensor\n", + " x = self.multi_2d_cnn_1b(x)\n", + " # N x 138 x 12 x 224 tensor\n", + " x = self.multi_2d_cnn_1c(x)\n", + " # N x 69 x 12 x 298\n", + " x = self.multi_2d_cnn_2a(x)\n", + "\n", + " x = self.multi_2d_cnn_2b(x)\n", + "\n", + " x = self.multi_2d_cnn_2c(x)\n", + "\n", + " x = self.multi_2d_cnn_2d(x)\n", + "\n", + " x = self.output(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "class Inception_Tabular_15_1(nn.Module):\n", + "\n", + " def __init__(self, initial_kernel_num=64):\n", + " super(Inception_Tabular_15_1, self).__init__()\n", + "\n", + " multi_2d_cnn = Multi_2D_CNN_block\n", + " conv_block = BasicConv2d\n", + "\n", + " self.conv_1 = conv_block(1, 64, kernel_size=(15, 1), stride=(2, 1))\n", + "\n", + " self.multi_2d_cnn_1a = nn.Sequential(\n", + " multi_2d_cnn(in_channels=64, num_kernel=initial_kernel_num),\n", + " multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num),\n", + " nn.MaxPool2d(kernel_size=(3, 1))\n", + " )\n", + "\n", + " self.multi_2d_cnn_1b = nn.Sequential(\n", + " multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num * 1.5),\n", + " multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 1.5),\n", + " nn.MaxPool2d(kernel_size=(3, 1))\n", + " )\n", + " self.multi_2d_cnn_1c = nn.Sequential(\n", + " multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 2),\n", + " multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 2),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + "\n", + " self.multi_2d_cnn_2a = nn.Sequential(\n", + " multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 3),\n", + " multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 3),\n", + " multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 4),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2b = nn.Sequential(\n", + " multi_2d_cnn(in_channels=597, num_kernel=initial_kernel_num * 5),\n", + " multi_2d_cnn(in_channels=746, 
num_kernel=initial_kernel_num * 6),\n", + " multi_2d_cnn(in_channels=896, num_kernel=initial_kernel_num * 7),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2c = nn.Sequential(\n", + " multi_2d_cnn(in_channels=1045, num_kernel=initial_kernel_num * 8),\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2d = nn.Sequential(\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 12),\n", + " multi_2d_cnn(in_channels=1792, num_kernel=initial_kernel_num * 14),\n", + " multi_2d_cnn(in_channels=2090, num_kernel=initial_kernel_num * 16),\n", + " )\n", + " self.output = nn.Sequential(\n", + " nn.AdaptiveAvgPool2d((1, 1)),\n", + " nn.Flatten(),\n", + " nn.Dropout(0.5),\n", + " nn.Linear(2389, 1),\n", + " # nn.Sigmoid()\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.conv_1(x)\n", + " # N x 1250 x 12 x 64 tensor\n", + " x = self.multi_2d_cnn_1a(x)\n", + " # N x 416 x 12 x 149 tensor\n", + " x = self.multi_2d_cnn_1b(x)\n", + " # N x 138 x 12 x 224 tensor\n", + " x = self.multi_2d_cnn_1c(x)\n", + " # N x 69 x 12 x 298\n", + " x = self.multi_2d_cnn_2a(x)\n", + "\n", + " x = self.multi_2d_cnn_2b(x)\n", + "\n", + " x = self.multi_2d_cnn_2c(x)\n", + "\n", + " x = self.multi_2d_cnn_2d(x)\n", + "\n", + " x = self.output(x)\n", + "\n", + " return x\n", + "\n", + "class Inception_Tabular_5_Linears(nn.Module):\n", + "\n", + " def __init__(self, initial_kernel_num=64):\n", + " super(Inception_Tabular_5_Linears, self).__init__()\n", + "\n", + " multi_2d_cnn = Multi_2D_CNN_block\n", + " conv_block = BasicConv2d\n", + "\n", + " self.conv_1 = conv_block(1, 64, kernel_size=(7, 3), stride=(2, 1))\n", + "\n", + " self.multi_2d_cnn_1a = nn.Sequential(\n", + " multi_2d_cnn(in_channels=64, num_kernel=initial_kernel_num),\n", + " multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num),\n", + " nn.MaxPool2d(kernel_size=(3, 1))\n", + " )\n", + "\n", + " self.multi_2d_cnn_1b = nn.Sequential(\n", + " multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num * 1.5),\n", + " multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 1.5),\n", + " nn.MaxPool2d(kernel_size=(3, 1))\n", + " )\n", + " self.multi_2d_cnn_1c = nn.Sequential(\n", + " multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 2),\n", + " multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 2),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + "\n", + " self.multi_2d_cnn_2a = nn.Sequential(\n", + " multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 3),\n", + " multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 3),\n", + " multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 4),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2b = nn.Sequential(\n", + " multi_2d_cnn(in_channels=597, num_kernel=initial_kernel_num * 5),\n", + " multi_2d_cnn(in_channels=746, num_kernel=initial_kernel_num * 6),\n", + " multi_2d_cnn(in_channels=896, num_kernel=initial_kernel_num * 7),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " self.multi_2d_cnn_2c = nn.Sequential(\n", + " multi_2d_cnn(in_channels=1045, num_kernel=initial_kernel_num * 8),\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n", + " multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n", + " nn.MaxPool2d(kernel_size=(2, 1))\n", + " )\n", + " 
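# The classifier head of this variant (defined below) swaps the single\n",
+    "        # Linear(2389, 1) of the models above for five stacked Linear layers;\n",
+    "        # note that with no activation functions between them, the stack still\n",
+    "        # composes to a single affine map of the 2389 pooled features.\n",
+    "        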
self.multi_2d_cnn_2d = nn.Sequential(\n",
+    "            multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 12),\n",
+    "            multi_2d_cnn(in_channels=1792, num_kernel=initial_kernel_num * 14),\n",
+    "            multi_2d_cnn(in_channels=2090, num_kernel=initial_kernel_num * 16),\n",
+    "        )\n",
+    "        self.output = nn.Sequential(\n",
+    "            nn.AdaptiveAvgPool2d((1, 1)),\n",
+    "            nn.Flatten(),\n",
+    "            nn.Dropout(0.5),\n",
+    "            nn.Linear(2389, 1200),\n",
+    "            nn.Linear(1200, 400),\n",
+    "            nn.Linear(400, 100),\n",
+    "            nn.Linear(100, 20),\n",
+    "            nn.Linear(20, 1),\n",
+    "            # nn.Sigmoid()\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        x = self.conv_1(x)\n",
+    "        # N x 1250 x 12 x 64 tensor\n",
+    "        x = self.multi_2d_cnn_1a(x)\n",
+    "        # N x 416 x 12 x 149 tensor\n",
+    "        x = self.multi_2d_cnn_1b(x)\n",
+    "        # N x 138 x 12 x 224 tensor\n",
+    "        x = self.multi_2d_cnn_1c(x)\n",
+    "        # N x 69 x 12 x 298\n",
+    "        x = self.multi_2d_cnn_2a(x)\n",
+    "\n",
+    "        x = self.multi_2d_cnn_2b(x)\n",
+    "\n",
+    "        x = self.multi_2d_cnn_2c(x)\n",
+    "\n",
+    "        x = self.multi_2d_cnn_2d(x)\n",
+    "\n",
+    "        x = self.output(x)\n",
+    "\n",
+    "        return x\n",
+    "\n",
+    "class Inception_Tabular_5_Linears_BatchNorm(nn.Module):\n",
+    "\n",
+    "    def __init__(self, initial_kernel_num=64):\n",
+    "        super(Inception_Tabular_5_Linears_BatchNorm, self).__init__()\n",
+    "\n",
+    "        multi_2d_cnn = Multi_2D_CNN_block\n",
+    "        conv_block = BasicConv2d\n",
+    "\n",
+    "        self.conv_1 = conv_block(1, 64, kernel_size=(7, 3), stride=(2, 1))\n",
+    "\n",
+    "        self.multi_2d_cnn_1a = nn.Sequential(\n",
+    "            multi_2d_cnn(in_channels=64, num_kernel=initial_kernel_num),\n",
+    "            multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num),\n",
+    "            nn.MaxPool2d(kernel_size=(3, 1))\n",
+    "        )\n",
+    "\n",
+    "        self.multi_2d_cnn_1b = nn.Sequential(\n",
+    "            multi_2d_cnn(in_channels=149, num_kernel=initial_kernel_num * 1.5),\n",
+    "            multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 1.5),\n",
+    "            nn.MaxPool2d(kernel_size=(3, 1))\n",
+    "        )\n",
+    "        self.multi_2d_cnn_1c = nn.Sequential(\n",
+    "            multi_2d_cnn(in_channels=224, num_kernel=initial_kernel_num * 2),\n",
+    "            multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 2),\n",
+    "            nn.MaxPool2d(kernel_size=(2, 1))\n",
+    "        )\n",
+    "\n",
+    "        self.multi_2d_cnn_2a = nn.Sequential(\n",
+    "            multi_2d_cnn(in_channels=298, num_kernel=initial_kernel_num * 3),\n",
+    "            multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 3),\n",
+    "            multi_2d_cnn(in_channels=448, num_kernel=initial_kernel_num * 4),\n",
+    "            nn.MaxPool2d(kernel_size=(2, 1))\n",
+    "        )\n",
+    "        self.multi_2d_cnn_2b = nn.Sequential(\n",
+    "            multi_2d_cnn(in_channels=597, num_kernel=initial_kernel_num * 5),\n",
+    "            multi_2d_cnn(in_channels=746, num_kernel=initial_kernel_num * 6),\n",
+    "            multi_2d_cnn(in_channels=896, num_kernel=initial_kernel_num * 7),\n",
+    "            nn.MaxPool2d(kernel_size=(2, 1))\n",
+    "        )\n",
+    "        self.multi_2d_cnn_2c = nn.Sequential(\n",
+    "            multi_2d_cnn(in_channels=1045, num_kernel=initial_kernel_num * 8),\n",
+    "            multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n",
+    "            multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 8),\n",
+    "            nn.MaxPool2d(kernel_size=(2, 1))\n",
+    "        )\n",
+    "        self.multi_2d_cnn_2d = nn.Sequential(\n",
+    "            multi_2d_cnn(in_channels=1194, num_kernel=initial_kernel_num * 12),\n",
+    "            multi_2d_cnn(in_channels=1792, num_kernel=initial_kernel_num * 14),\n",
+    "            multi_2d_cnn(in_channels=2090, num_kernel=initial_kernel_num * 16),\n",
+    "        )\n",
+    "        self.output = nn.Sequential(\n",
+    "            nn.AdaptiveAvgPool2d((1, 1)),\n",
+    "            
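# AdaptiveAvgPool2d averages the remaining H x W grid (8 x 12 for the\n",
+    "            # 1 x 2500 x 12 inputs profiled below) down to [N, 2389, 1, 1];\n",
+    "            # Flatten then hands [N, 2389] to the linear head.\n",
+    "            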
nn.Flatten(),\n", + " nn.Dropout(0.5),\n", + " nn.Linear(2389, 1200),\n", + " nn.BatchNorm1d(1200),\n", + " nn.Linear(1200, 400),\n", + " nn.BatchNorm1d(400),\n", + " nn.Linear(400, 100),\n", + " nn.BatchNorm1d(100),\n", + " nn.Linear(100, 20),\n", + " nn.BatchNorm1d(20),\n", + " nn.Linear(20, 1)\n", + " # nn.Sigmoid()\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.conv_1(x)\n", + " # N x 1250 x 12 x 64 tensor\n", + " x = self.multi_2d_cnn_1a(x)\n", + " # N x 416 x 12 x 149 tensor\n", + " x = self.multi_2d_cnn_1b(x)\n", + " # N x 138 x 12 x 224 tensor\n", + " x = self.multi_2d_cnn_1c(x)\n", + " # N x 69 x 12 x 298\n", + " x = self.multi_2d_cnn_2a(x)\n", + "\n", + " x = self.multi_2d_cnn_2b(x)\n", + "\n", + " x = self.multi_2d_cnn_2c(x)\n", + "\n", + " x = self.multi_2d_cnn_2d(x)\n", + "\n", + " x = self.output(x)\n", + "\n", + " return x\n", + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "model = Inception_Tabular_15_3()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " ZeroPad2d-1 [-1, 1, 2514, 14] 0\n", + " Conv2d-2 [-1, 64, 1250, 12] 2,880\n", + " BatchNorm2d-3 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-4 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-5 [-1, 64, 1250, 12] 0\n", + " Conv2d-6 [-1, 21, 1250, 12] 1,344\n", + " BatchNorm2d-7 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-8 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-9 [-1, 64, 1250, 12] 0\n", + " Conv2d-10 [-1, 32, 1250, 12] 2,048\n", + " BatchNorm2d-11 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-12 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-13 [-1, 32, 1252, 14] 0\n", + " Conv2d-14 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-15 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-16 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-17 [-1, 64, 1250, 12] 0\n", + " Conv2d-18 [-1, 21, 1250, 12] 1,344\n", + " BatchNorm2d-19 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-20 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-21 [-1, 21, 1252, 14] 0\n", + " Conv2d-22 [-1, 32, 1250, 12] 6,048\n", + " BatchNorm2d-23 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-24 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-25 [-1, 32, 1252, 14] 0\n", + " Conv2d-26 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-27 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-28 [-1, 64, 1250, 12] 0\n", + " BatchNorm2d-29 [-1, 149, 1250, 12] 298\n", + "Multi_2D_CNN_block-30 [-1, 149, 1250, 12] 0\n", + " ZeroPad2d-31 [-1, 149, 1250, 12] 0\n", + " Conv2d-32 [-1, 21, 1250, 12] 3,129\n", + " BatchNorm2d-33 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-34 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-35 [-1, 149, 1250, 12] 0\n", + " Conv2d-36 [-1, 32, 1250, 12] 4,768\n", + " BatchNorm2d-37 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-38 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-39 [-1, 32, 1252, 14] 0\n", + " Conv2d-40 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-41 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-42 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-43 [-1, 149, 1250, 12] 0\n", + " Conv2d-44 [-1, 21, 1250, 12] 3,129\n", + " BatchNorm2d-45 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-46 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-47 [-1, 21, 1252, 14] 0\n", + " Conv2d-48 [-1, 32, 1250, 12] 6,048\n", + " BatchNorm2d-49 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-50 
[-1, 32, 1250, 12] 0\n", + " ZeroPad2d-51 [-1, 32, 1252, 14] 0\n", + " Conv2d-52 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-53 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-54 [-1, 64, 1250, 12] 0\n", + " BatchNorm2d-55 [-1, 149, 1250, 12] 298\n", + "Multi_2D_CNN_block-56 [-1, 149, 1250, 12] 0\n", + " MaxPool2d-57 [-1, 149, 416, 12] 0\n", + " ZeroPad2d-58 [-1, 149, 416, 12] 0\n", + " Conv2d-59 [-1, 32, 416, 12] 4,768\n", + " BatchNorm2d-60 [-1, 32, 416, 12] 64\n", + " BasicConv2d-61 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-62 [-1, 149, 416, 12] 0\n", + " Conv2d-63 [-1, 48, 416, 12] 7,152\n", + " BatchNorm2d-64 [-1, 48, 416, 12] 96\n", + " BasicConv2d-65 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-66 [-1, 48, 418, 14] 0\n", + " Conv2d-67 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-68 [-1, 96, 416, 12] 192\n", + " BasicConv2d-69 [-1, 96, 416, 12] 0\n", + " ZeroPad2d-70 [-1, 149, 416, 12] 0\n", + " Conv2d-71 [-1, 32, 416, 12] 4,768\n", + " BatchNorm2d-72 [-1, 32, 416, 12] 64\n", + " BasicConv2d-73 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-74 [-1, 32, 418, 14] 0\n", + " Conv2d-75 [-1, 48, 416, 12] 13,824\n", + " BatchNorm2d-76 [-1, 48, 416, 12] 96\n", + " BasicConv2d-77 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-78 [-1, 48, 418, 14] 0\n", + " Conv2d-79 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-80 [-1, 96, 416, 12] 192\n", + " BasicConv2d-81 [-1, 96, 416, 12] 0\n", + " BatchNorm2d-82 [-1, 224, 416, 12] 448\n", + "Multi_2D_CNN_block-83 [-1, 224, 416, 12] 0\n", + " ZeroPad2d-84 [-1, 224, 416, 12] 0\n", + " Conv2d-85 [-1, 32, 416, 12] 7,168\n", + " BatchNorm2d-86 [-1, 32, 416, 12] 64\n", + " BasicConv2d-87 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-88 [-1, 224, 416, 12] 0\n", + " Conv2d-89 [-1, 48, 416, 12] 10,752\n", + " BatchNorm2d-90 [-1, 48, 416, 12] 96\n", + " BasicConv2d-91 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-92 [-1, 48, 418, 14] 0\n", + " Conv2d-93 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-94 [-1, 96, 416, 12] 192\n", + " BasicConv2d-95 [-1, 96, 416, 12] 0\n", + " ZeroPad2d-96 [-1, 224, 416, 12] 0\n", + " Conv2d-97 [-1, 32, 416, 12] 7,168\n", + " BatchNorm2d-98 [-1, 32, 416, 12] 64\n", + " BasicConv2d-99 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-100 [-1, 32, 418, 14] 0\n", + " Conv2d-101 [-1, 48, 416, 12] 13,824\n", + " BatchNorm2d-102 [-1, 48, 416, 12] 96\n", + " BasicConv2d-103 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-104 [-1, 48, 418, 14] 0\n", + " Conv2d-105 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-106 [-1, 96, 416, 12] 192\n", + " BasicConv2d-107 [-1, 96, 416, 12] 0\n", + " BatchNorm2d-108 [-1, 224, 416, 12] 448\n", + "Multi_2D_CNN_block-109 [-1, 224, 416, 12] 0\n", + " MaxPool2d-110 [-1, 224, 138, 12] 0\n", + " ZeroPad2d-111 [-1, 224, 138, 12] 0\n", + " Conv2d-112 [-1, 42, 138, 12] 9,408\n", + " BatchNorm2d-113 [-1, 42, 138, 12] 84\n", + " BasicConv2d-114 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-115 [-1, 224, 138, 12] 0\n", + " Conv2d-116 [-1, 64, 138, 12] 14,336\n", + " BatchNorm2d-117 [-1, 64, 138, 12] 128\n", + " BasicConv2d-118 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-119 [-1, 64, 140, 14] 0\n", + " Conv2d-120 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-121 [-1, 128, 138, 12] 256\n", + " BasicConv2d-122 [-1, 128, 138, 12] 0\n", + " ZeroPad2d-123 [-1, 224, 138, 12] 0\n", + " Conv2d-124 [-1, 42, 138, 12] 9,408\n", + " BatchNorm2d-125 [-1, 42, 138, 12] 84\n", + " BasicConv2d-126 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-127 [-1, 42, 140, 14] 0\n", + " Conv2d-128 [-1, 64, 138, 12] 24,192\n", + " BatchNorm2d-129 [-1, 64, 138, 12] 128\n", + " BasicConv2d-130 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-131 [-1, 64, 
140, 14] 0\n", + " Conv2d-132 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-133 [-1, 128, 138, 12] 256\n", + " BasicConv2d-134 [-1, 128, 138, 12] 0\n", + " BatchNorm2d-135 [-1, 298, 138, 12] 596\n", + "Multi_2D_CNN_block-136 [-1, 298, 138, 12] 0\n", + " ZeroPad2d-137 [-1, 298, 138, 12] 0\n", + " Conv2d-138 [-1, 42, 138, 12] 12,516\n", + " BatchNorm2d-139 [-1, 42, 138, 12] 84\n", + " BasicConv2d-140 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-141 [-1, 298, 138, 12] 0\n", + " Conv2d-142 [-1, 64, 138, 12] 19,072\n", + " BatchNorm2d-143 [-1, 64, 138, 12] 128\n", + " BasicConv2d-144 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-145 [-1, 64, 140, 14] 0\n", + " Conv2d-146 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-147 [-1, 128, 138, 12] 256\n", + " BasicConv2d-148 [-1, 128, 138, 12] 0\n", + " ZeroPad2d-149 [-1, 298, 138, 12] 0\n", + " Conv2d-150 [-1, 42, 138, 12] 12,516\n", + " BatchNorm2d-151 [-1, 42, 138, 12] 84\n", + " BasicConv2d-152 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-153 [-1, 42, 140, 14] 0\n", + " Conv2d-154 [-1, 64, 138, 12] 24,192\n", + " BatchNorm2d-155 [-1, 64, 138, 12] 128\n", + " BasicConv2d-156 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-157 [-1, 64, 140, 14] 0\n", + " Conv2d-158 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-159 [-1, 128, 138, 12] 256\n", + " BasicConv2d-160 [-1, 128, 138, 12] 0\n", + " BatchNorm2d-161 [-1, 298, 138, 12] 596\n", + "Multi_2D_CNN_block-162 [-1, 298, 138, 12] 0\n", + " MaxPool2d-163 [-1, 298, 69, 12] 0\n", + " ZeroPad2d-164 [-1, 298, 69, 12] 0\n", + " Conv2d-165 [-1, 64, 69, 12] 19,072\n", + " BatchNorm2d-166 [-1, 64, 69, 12] 128\n", + " BasicConv2d-167 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-168 [-1, 298, 69, 12] 0\n", + " Conv2d-169 [-1, 96, 69, 12] 28,608\n", + " BatchNorm2d-170 [-1, 96, 69, 12] 192\n", + " BasicConv2d-171 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-172 [-1, 96, 71, 14] 0\n", + " Conv2d-173 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-174 [-1, 192, 69, 12] 384\n", + " BasicConv2d-175 [-1, 192, 69, 12] 0\n", + " ZeroPad2d-176 [-1, 298, 69, 12] 0\n", + " Conv2d-177 [-1, 64, 69, 12] 19,072\n", + " BatchNorm2d-178 [-1, 64, 69, 12] 128\n", + " BasicConv2d-179 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-180 [-1, 64, 71, 14] 0\n", + " Conv2d-181 [-1, 96, 69, 12] 55,296\n", + " BatchNorm2d-182 [-1, 96, 69, 12] 192\n", + " BasicConv2d-183 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-184 [-1, 96, 71, 14] 0\n", + " Conv2d-185 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-186 [-1, 192, 69, 12] 384\n", + " BasicConv2d-187 [-1, 192, 69, 12] 0\n", + " BatchNorm2d-188 [-1, 448, 69, 12] 896\n", + "Multi_2D_CNN_block-189 [-1, 448, 69, 12] 0\n", + " ZeroPad2d-190 [-1, 448, 69, 12] 0\n", + " Conv2d-191 [-1, 64, 69, 12] 28,672\n", + " BatchNorm2d-192 [-1, 64, 69, 12] 128\n", + " BasicConv2d-193 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-194 [-1, 448, 69, 12] 0\n", + " Conv2d-195 [-1, 96, 69, 12] 43,008\n", + " BatchNorm2d-196 [-1, 96, 69, 12] 192\n", + " BasicConv2d-197 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-198 [-1, 96, 71, 14] 0\n", + " Conv2d-199 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-200 [-1, 192, 69, 12] 384\n", + " BasicConv2d-201 [-1, 192, 69, 12] 0\n", + " ZeroPad2d-202 [-1, 448, 69, 12] 0\n", + " Conv2d-203 [-1, 64, 69, 12] 28,672\n", + " BatchNorm2d-204 [-1, 64, 69, 12] 128\n", + " BasicConv2d-205 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-206 [-1, 64, 71, 14] 0\n", + " Conv2d-207 [-1, 96, 69, 12] 55,296\n", + " BatchNorm2d-208 [-1, 96, 69, 12] 192\n", + " BasicConv2d-209 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-210 [-1, 96, 71, 14] 0\n", + " Conv2d-211 [-1, 192, 69, 12] 165,888\n", + " 
BatchNorm2d-212 [-1, 192, 69, 12] 384\n", + " BasicConv2d-213 [-1, 192, 69, 12] 0\n", + " BatchNorm2d-214 [-1, 448, 69, 12] 896\n", + "Multi_2D_CNN_block-215 [-1, 448, 69, 12] 0\n", + " ZeroPad2d-216 [-1, 448, 69, 12] 0\n", + " Conv2d-217 [-1, 85, 69, 12] 38,080\n", + " BatchNorm2d-218 [-1, 85, 69, 12] 170\n", + " BasicConv2d-219 [-1, 85, 69, 12] 0\n", + " ZeroPad2d-220 [-1, 448, 69, 12] 0\n", + " Conv2d-221 [-1, 128, 69, 12] 57,344\n", + " BatchNorm2d-222 [-1, 128, 69, 12] 256\n", + " BasicConv2d-223 [-1, 128, 69, 12] 0\n", + " ZeroPad2d-224 [-1, 128, 71, 14] 0\n", + " Conv2d-225 [-1, 256, 69, 12] 294,912\n", + " BatchNorm2d-226 [-1, 256, 69, 12] 512\n", + " BasicConv2d-227 [-1, 256, 69, 12] 0\n", + " ZeroPad2d-228 [-1, 448, 69, 12] 0\n", + " Conv2d-229 [-1, 85, 69, 12] 38,080\n", + " BatchNorm2d-230 [-1, 85, 69, 12] 170\n", + " BasicConv2d-231 [-1, 85, 69, 12] 0\n", + " ZeroPad2d-232 [-1, 85, 71, 14] 0\n", + " Conv2d-233 [-1, 128, 69, 12] 97,920\n", + " BatchNorm2d-234 [-1, 128, 69, 12] 256\n", + " BasicConv2d-235 [-1, 128, 69, 12] 0\n", + " ZeroPad2d-236 [-1, 128, 71, 14] 0\n", + " Conv2d-237 [-1, 256, 69, 12] 294,912\n", + " BatchNorm2d-238 [-1, 256, 69, 12] 512\n", + " BasicConv2d-239 [-1, 256, 69, 12] 0\n", + " BatchNorm2d-240 [-1, 597, 69, 12] 1,194\n", + "Multi_2D_CNN_block-241 [-1, 597, 69, 12] 0\n", + " MaxPool2d-242 [-1, 597, 34, 12] 0\n", + " ZeroPad2d-243 [-1, 597, 34, 12] 0\n", + " Conv2d-244 [-1, 106, 34, 12] 63,282\n", + " BatchNorm2d-245 [-1, 106, 34, 12] 212\n", + " BasicConv2d-246 [-1, 106, 34, 12] 0\n", + " ZeroPad2d-247 [-1, 597, 34, 12] 0\n", + " Conv2d-248 [-1, 160, 34, 12] 95,520\n", + " BatchNorm2d-249 [-1, 160, 34, 12] 320\n", + " BasicConv2d-250 [-1, 160, 34, 12] 0\n", + " ZeroPad2d-251 [-1, 160, 36, 14] 0\n", + " Conv2d-252 [-1, 320, 34, 12] 460,800\n", + " BatchNorm2d-253 [-1, 320, 34, 12] 640\n", + " BasicConv2d-254 [-1, 320, 34, 12] 0\n", + " ZeroPad2d-255 [-1, 597, 34, 12] 0\n", + " Conv2d-256 [-1, 106, 34, 12] 63,282\n", + " BatchNorm2d-257 [-1, 106, 34, 12] 212\n", + " BasicConv2d-258 [-1, 106, 34, 12] 0\n", + " ZeroPad2d-259 [-1, 106, 36, 14] 0\n", + " Conv2d-260 [-1, 160, 34, 12] 152,640\n", + " BatchNorm2d-261 [-1, 160, 34, 12] 320\n", + " BasicConv2d-262 [-1, 160, 34, 12] 0\n", + " ZeroPad2d-263 [-1, 160, 36, 14] 0\n", + " Conv2d-264 [-1, 320, 34, 12] 460,800\n", + " BatchNorm2d-265 [-1, 320, 34, 12] 640\n", + " BasicConv2d-266 [-1, 320, 34, 12] 0\n", + " BatchNorm2d-267 [-1, 746, 34, 12] 1,492\n", + "Multi_2D_CNN_block-268 [-1, 746, 34, 12] 0\n", + " ZeroPad2d-269 [-1, 746, 34, 12] 0\n", + " Conv2d-270 [-1, 128, 34, 12] 95,488\n", + " BatchNorm2d-271 [-1, 128, 34, 12] 256\n", + " BasicConv2d-272 [-1, 128, 34, 12] 0\n", + " ZeroPad2d-273 [-1, 746, 34, 12] 0\n", + " Conv2d-274 [-1, 192, 34, 12] 143,232\n", + " BatchNorm2d-275 [-1, 192, 34, 12] 384\n", + " BasicConv2d-276 [-1, 192, 34, 12] 0\n", + " ZeroPad2d-277 [-1, 192, 36, 14] 0\n", + " Conv2d-278 [-1, 384, 34, 12] 663,552\n", + " BatchNorm2d-279 [-1, 384, 34, 12] 768\n", + " BasicConv2d-280 [-1, 384, 34, 12] 0\n", + " ZeroPad2d-281 [-1, 746, 34, 12] 0\n", + " Conv2d-282 [-1, 128, 34, 12] 95,488\n", + " BatchNorm2d-283 [-1, 128, 34, 12] 256\n", + " BasicConv2d-284 [-1, 128, 34, 12] 0\n", + " ZeroPad2d-285 [-1, 128, 36, 14] 0\n", + " Conv2d-286 [-1, 192, 34, 12] 221,184\n", + " BatchNorm2d-287 [-1, 192, 34, 12] 384\n", + " BasicConv2d-288 [-1, 192, 34, 12] 0\n", + " ZeroPad2d-289 [-1, 192, 36, 14] 0\n", + " Conv2d-290 [-1, 384, 34, 12] 663,552\n", + " BatchNorm2d-291 [-1, 384, 34, 12] 768\n", + " 
BasicConv2d-292 [-1, 384, 34, 12] 0\n", + " BatchNorm2d-293 [-1, 896, 34, 12] 1,792\n", + "Multi_2D_CNN_block-294 [-1, 896, 34, 12] 0\n", + " ZeroPad2d-295 [-1, 896, 34, 12] 0\n", + " Conv2d-296 [-1, 149, 34, 12] 133,504\n", + " BatchNorm2d-297 [-1, 149, 34, 12] 298\n", + " BasicConv2d-298 [-1, 149, 34, 12] 0\n", + " ZeroPad2d-299 [-1, 896, 34, 12] 0\n", + " Conv2d-300 [-1, 224, 34, 12] 200,704\n", + " BatchNorm2d-301 [-1, 224, 34, 12] 448\n", + " BasicConv2d-302 [-1, 224, 34, 12] 0\n", + " ZeroPad2d-303 [-1, 224, 36, 14] 0\n", + " Conv2d-304 [-1, 448, 34, 12] 903,168\n", + " BatchNorm2d-305 [-1, 448, 34, 12] 896\n", + " BasicConv2d-306 [-1, 448, 34, 12] 0\n", + " ZeroPad2d-307 [-1, 896, 34, 12] 0\n", + " Conv2d-308 [-1, 149, 34, 12] 133,504\n", + " BatchNorm2d-309 [-1, 149, 34, 12] 298\n", + " BasicConv2d-310 [-1, 149, 34, 12] 0\n", + " ZeroPad2d-311 [-1, 149, 36, 14] 0\n", + " Conv2d-312 [-1, 224, 34, 12] 300,384\n", + " BatchNorm2d-313 [-1, 224, 34, 12] 448\n", + " BasicConv2d-314 [-1, 224, 34, 12] 0\n", + " ZeroPad2d-315 [-1, 224, 36, 14] 0\n", + " Conv2d-316 [-1, 448, 34, 12] 903,168\n", + " BatchNorm2d-317 [-1, 448, 34, 12] 896\n", + " BasicConv2d-318 [-1, 448, 34, 12] 0\n", + " BatchNorm2d-319 [-1, 1045, 34, 12] 2,090\n", + "Multi_2D_CNN_block-320 [-1, 1045, 34, 12] 0\n", + " MaxPool2d-321 [-1, 1045, 17, 12] 0\n", + " ZeroPad2d-322 [-1, 1045, 17, 12] 0\n", + " Conv2d-323 [-1, 170, 17, 12] 177,650\n", + " BatchNorm2d-324 [-1, 170, 17, 12] 340\n", + " BasicConv2d-325 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-326 [-1, 1045, 17, 12] 0\n", + " Conv2d-327 [-1, 256, 17, 12] 267,520\n", + " BatchNorm2d-328 [-1, 256, 17, 12] 512\n", + " BasicConv2d-329 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-330 [-1, 256, 19, 14] 0\n", + " Conv2d-331 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-332 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-333 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-334 [-1, 1045, 17, 12] 0\n", + " Conv2d-335 [-1, 170, 17, 12] 177,650\n", + " BatchNorm2d-336 [-1, 170, 17, 12] 340\n", + " BasicConv2d-337 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-338 [-1, 170, 19, 14] 0\n", + " Conv2d-339 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-340 [-1, 256, 17, 12] 512\n", + " BasicConv2d-341 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-342 [-1, 256, 19, 14] 0\n", + " Conv2d-343 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-344 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-345 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-346 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-347 [-1, 1194, 17, 12] 0\n", + " ZeroPad2d-348 [-1, 1194, 17, 12] 0\n", + " Conv2d-349 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-350 [-1, 170, 17, 12] 340\n", + " BasicConv2d-351 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-352 [-1, 1194, 17, 12] 0\n", + " Conv2d-353 [-1, 256, 17, 12] 305,664\n", + " BatchNorm2d-354 [-1, 256, 17, 12] 512\n", + " BasicConv2d-355 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-356 [-1, 256, 19, 14] 0\n", + " Conv2d-357 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-358 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-359 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-360 [-1, 1194, 17, 12] 0\n", + " Conv2d-361 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-362 [-1, 170, 17, 12] 340\n", + " BasicConv2d-363 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-364 [-1, 170, 19, 14] 0\n", + " Conv2d-365 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-366 [-1, 256, 17, 12] 512\n", + " BasicConv2d-367 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-368 [-1, 256, 19, 14] 0\n", + " Conv2d-369 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-370 [-1, 512, 17, 12] 1,024\n", + " 
BasicConv2d-371 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-372 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-373 [-1, 1194, 17, 12] 0\n", + " ZeroPad2d-374 [-1, 1194, 17, 12] 0\n", + " Conv2d-375 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-376 [-1, 170, 17, 12] 340\n", + " BasicConv2d-377 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-378 [-1, 1194, 17, 12] 0\n", + " Conv2d-379 [-1, 256, 17, 12] 305,664\n", + " BatchNorm2d-380 [-1, 256, 17, 12] 512\n", + " BasicConv2d-381 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-382 [-1, 256, 19, 14] 0\n", + " Conv2d-383 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-384 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-385 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-386 [-1, 1194, 17, 12] 0\n", + " Conv2d-387 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-388 [-1, 170, 17, 12] 340\n", + " BasicConv2d-389 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-390 [-1, 170, 19, 14] 0\n", + " Conv2d-391 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-392 [-1, 256, 17, 12] 512\n", + " BasicConv2d-393 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-394 [-1, 256, 19, 14] 0\n", + " Conv2d-395 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-396 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-397 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-398 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-399 [-1, 1194, 17, 12] 0\n", + " MaxPool2d-400 [-1, 1194, 8, 12] 0\n", + " ZeroPad2d-401 [-1, 1194, 8, 12] 0\n", + " Conv2d-402 [-1, 256, 8, 12] 305,664\n", + " BatchNorm2d-403 [-1, 256, 8, 12] 512\n", + " BasicConv2d-404 [-1, 256, 8, 12] 0\n", + " ZeroPad2d-405 [-1, 1194, 8, 12] 0\n", + " Conv2d-406 [-1, 384, 8, 12] 458,496\n", + " BatchNorm2d-407 [-1, 384, 8, 12] 768\n", + " BasicConv2d-408 [-1, 384, 8, 12] 0\n", + " ZeroPad2d-409 [-1, 384, 10, 14] 0\n", + " Conv2d-410 [-1, 768, 8, 12] 2,654,208\n", + " BatchNorm2d-411 [-1, 768, 8, 12] 1,536\n", + " BasicConv2d-412 [-1, 768, 8, 12] 0\n", + " ZeroPad2d-413 [-1, 1194, 8, 12] 0\n", + " Conv2d-414 [-1, 256, 8, 12] 305,664\n", + " BatchNorm2d-415 [-1, 256, 8, 12] 512\n", + " BasicConv2d-416 [-1, 256, 8, 12] 0\n", + " ZeroPad2d-417 [-1, 256, 10, 14] 0\n", + " Conv2d-418 [-1, 384, 8, 12] 884,736\n", + " BatchNorm2d-419 [-1, 384, 8, 12] 768\n", + " BasicConv2d-420 [-1, 384, 8, 12] 0\n", + " ZeroPad2d-421 [-1, 384, 10, 14] 0\n", + " Conv2d-422 [-1, 768, 8, 12] 2,654,208\n", + " BatchNorm2d-423 [-1, 768, 8, 12] 1,536\n", + " BasicConv2d-424 [-1, 768, 8, 12] 0\n", + " BatchNorm2d-425 [-1, 1792, 8, 12] 3,584\n", + "Multi_2D_CNN_block-426 [-1, 1792, 8, 12] 0\n", + " ZeroPad2d-427 [-1, 1792, 8, 12] 0\n", + " Conv2d-428 [-1, 298, 8, 12] 534,016\n", + " BatchNorm2d-429 [-1, 298, 8, 12] 596\n", + " BasicConv2d-430 [-1, 298, 8, 12] 0\n", + " ZeroPad2d-431 [-1, 1792, 8, 12] 0\n", + " Conv2d-432 [-1, 448, 8, 12] 802,816\n", + " BatchNorm2d-433 [-1, 448, 8, 12] 896\n", + " BasicConv2d-434 [-1, 448, 8, 12] 0\n", + " ZeroPad2d-435 [-1, 448, 10, 14] 0\n", + " Conv2d-436 [-1, 896, 8, 12] 3,612,672\n", + " BatchNorm2d-437 [-1, 896, 8, 12] 1,792\n", + " BasicConv2d-438 [-1, 896, 8, 12] 0\n", + " ZeroPad2d-439 [-1, 1792, 8, 12] 0\n", + " Conv2d-440 [-1, 298, 8, 12] 534,016\n", + " BatchNorm2d-441 [-1, 298, 8, 12] 596\n", + " BasicConv2d-442 [-1, 298, 8, 12] 0\n", + " ZeroPad2d-443 [-1, 298, 10, 14] 0\n", + " Conv2d-444 [-1, 448, 8, 12] 1,201,536\n", + " BatchNorm2d-445 [-1, 448, 8, 12] 896\n", + " BasicConv2d-446 [-1, 448, 8, 12] 0\n", + " ZeroPad2d-447 [-1, 448, 10, 14] 0\n", + " Conv2d-448 [-1, 896, 8, 12] 3,612,672\n", + " BatchNorm2d-449 [-1, 896, 8, 12] 1,792\n", + " BasicConv2d-450 [-1, 896, 8, 12] 
0\n", + " BatchNorm2d-451 [-1, 2090, 8, 12] 4,180\n", + "Multi_2D_CNN_block-452 [-1, 2090, 8, 12] 0\n", + " ZeroPad2d-453 [-1, 2090, 8, 12] 0\n", + " Conv2d-454 [-1, 341, 8, 12] 712,690\n", + " BatchNorm2d-455 [-1, 341, 8, 12] 682\n", + " BasicConv2d-456 [-1, 341, 8, 12] 0\n", + " ZeroPad2d-457 [-1, 2090, 8, 12] 0\n", + " Conv2d-458 [-1, 512, 8, 12] 1,070,080\n", + " BatchNorm2d-459 [-1, 512, 8, 12] 1,024\n", + " BasicConv2d-460 [-1, 512, 8, 12] 0\n", + " ZeroPad2d-461 [-1, 512, 10, 14] 0\n", + " Conv2d-462 [-1, 1024, 8, 12] 4,718,592\n", + " BatchNorm2d-463 [-1, 1024, 8, 12] 2,048\n", + " BasicConv2d-464 [-1, 1024, 8, 12] 0\n", + " ZeroPad2d-465 [-1, 2090, 8, 12] 0\n", + " Conv2d-466 [-1, 341, 8, 12] 712,690\n", + " BatchNorm2d-467 [-1, 341, 8, 12] 682\n", + " BasicConv2d-468 [-1, 341, 8, 12] 0\n", + " ZeroPad2d-469 [-1, 341, 10, 14] 0\n", + " Conv2d-470 [-1, 512, 8, 12] 1,571,328\n", + " BatchNorm2d-471 [-1, 512, 8, 12] 1,024\n", + " BasicConv2d-472 [-1, 512, 8, 12] 0\n", + " ZeroPad2d-473 [-1, 512, 10, 14] 0\n", + " Conv2d-474 [-1, 1024, 8, 12] 4,718,592\n", + " BatchNorm2d-475 [-1, 1024, 8, 12] 2,048\n", + " BasicConv2d-476 [-1, 1024, 8, 12] 0\n", + " BatchNorm2d-477 [-1, 2389, 8, 12] 4,778\n", + "Multi_2D_CNN_block-478 [-1, 2389, 8, 12] 0\n", + "AdaptiveAvgPool2d-479 [-1, 2389, 1, 1] 0\n", + " Flatten-480 [-1, 2389] 0\n", + " Dropout-481 [-1, 2389] 0\n", + " Linear-482 [-1, 1] 2,390\n", + "================================================================\n", + "Total params: 49,721,334\n", + "Trainable params: 49,721,334\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.11\n", + "Forward/backward pass size (MB): 884.62\n", + "Params size (MB): 189.67\n", + "Estimated Total Size (MB): 1074.40\n", + "----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(Inception_Tabular_15_3(), input_size=(1, 2500, 12))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " ZeroPad2d-1 [-1, 1, 2506, 14] 0\n", + " Conv2d-2 [-1, 64, 1250, 12] 1,344\n", + " BatchNorm2d-3 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-4 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-5 [-1, 64, 1250, 12] 0\n", + " Conv2d-6 [-1, 21, 1250, 12] 1,344\n", + " BatchNorm2d-7 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-8 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-9 [-1, 64, 1250, 12] 0\n", + " Conv2d-10 [-1, 32, 1250, 12] 2,048\n", + " BatchNorm2d-11 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-12 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-13 [-1, 32, 1252, 14] 0\n", + " Conv2d-14 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-15 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-16 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-17 [-1, 64, 1250, 12] 0\n", + " Conv2d-18 [-1, 21, 1250, 12] 1,344\n", + " BatchNorm2d-19 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-20 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-21 [-1, 21, 1252, 14] 0\n", + " Conv2d-22 [-1, 32, 1250, 12] 6,048\n", + " BatchNorm2d-23 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-24 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-25 [-1, 32, 1252, 14] 0\n", + " Conv2d-26 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-27 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-28 [-1, 64, 1250, 12] 0\n", + " BatchNorm2d-29 [-1, 149, 
1250, 12] 298\n", + "Multi_2D_CNN_block-30 [-1, 149, 1250, 12] 0\n", + " ZeroPad2d-31 [-1, 149, 1250, 12] 0\n", + " Conv2d-32 [-1, 21, 1250, 12] 3,129\n", + " BatchNorm2d-33 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-34 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-35 [-1, 149, 1250, 12] 0\n", + " Conv2d-36 [-1, 32, 1250, 12] 4,768\n", + " BatchNorm2d-37 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-38 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-39 [-1, 32, 1252, 14] 0\n", + " Conv2d-40 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-41 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-42 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-43 [-1, 149, 1250, 12] 0\n", + " Conv2d-44 [-1, 21, 1250, 12] 3,129\n", + " BatchNorm2d-45 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-46 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-47 [-1, 21, 1252, 14] 0\n", + " Conv2d-48 [-1, 32, 1250, 12] 6,048\n", + " BatchNorm2d-49 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-50 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-51 [-1, 32, 1252, 14] 0\n", + " Conv2d-52 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-53 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-54 [-1, 64, 1250, 12] 0\n", + " BatchNorm2d-55 [-1, 149, 1250, 12] 298\n", + "Multi_2D_CNN_block-56 [-1, 149, 1250, 12] 0\n", + " MaxPool2d-57 [-1, 149, 416, 12] 0\n", + " ZeroPad2d-58 [-1, 149, 416, 12] 0\n", + " Conv2d-59 [-1, 32, 416, 12] 4,768\n", + " BatchNorm2d-60 [-1, 32, 416, 12] 64\n", + " BasicConv2d-61 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-62 [-1, 149, 416, 12] 0\n", + " Conv2d-63 [-1, 48, 416, 12] 7,152\n", + " BatchNorm2d-64 [-1, 48, 416, 12] 96\n", + " BasicConv2d-65 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-66 [-1, 48, 418, 14] 0\n", + " Conv2d-67 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-68 [-1, 96, 416, 12] 192\n", + " BasicConv2d-69 [-1, 96, 416, 12] 0\n", + " ZeroPad2d-70 [-1, 149, 416, 12] 0\n", + " Conv2d-71 [-1, 32, 416, 12] 4,768\n", + " BatchNorm2d-72 [-1, 32, 416, 12] 64\n", + " BasicConv2d-73 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-74 [-1, 32, 418, 14] 0\n", + " Conv2d-75 [-1, 48, 416, 12] 13,824\n", + " BatchNorm2d-76 [-1, 48, 416, 12] 96\n", + " BasicConv2d-77 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-78 [-1, 48, 418, 14] 0\n", + " Conv2d-79 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-80 [-1, 96, 416, 12] 192\n", + " BasicConv2d-81 [-1, 96, 416, 12] 0\n", + " BatchNorm2d-82 [-1, 224, 416, 12] 448\n", + "Multi_2D_CNN_block-83 [-1, 224, 416, 12] 0\n", + " ZeroPad2d-84 [-1, 224, 416, 12] 0\n", + " Conv2d-85 [-1, 32, 416, 12] 7,168\n", + " BatchNorm2d-86 [-1, 32, 416, 12] 64\n", + " BasicConv2d-87 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-88 [-1, 224, 416, 12] 0\n", + " Conv2d-89 [-1, 48, 416, 12] 10,752\n", + " BatchNorm2d-90 [-1, 48, 416, 12] 96\n", + " BasicConv2d-91 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-92 [-1, 48, 418, 14] 0\n", + " Conv2d-93 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-94 [-1, 96, 416, 12] 192\n", + " BasicConv2d-95 [-1, 96, 416, 12] 0\n", + " ZeroPad2d-96 [-1, 224, 416, 12] 0\n", + " Conv2d-97 [-1, 32, 416, 12] 7,168\n", + " BatchNorm2d-98 [-1, 32, 416, 12] 64\n", + " BasicConv2d-99 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-100 [-1, 32, 418, 14] 0\n", + " Conv2d-101 [-1, 48, 416, 12] 13,824\n", + " BatchNorm2d-102 [-1, 48, 416, 12] 96\n", + " BasicConv2d-103 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-104 [-1, 48, 418, 14] 0\n", + " Conv2d-105 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-106 [-1, 96, 416, 12] 192\n", + " BasicConv2d-107 [-1, 96, 416, 12] 0\n", + " BatchNorm2d-108 [-1, 224, 416, 12] 448\n", + "Multi_2D_CNN_block-109 [-1, 224, 416, 12] 0\n", + " MaxPool2d-110 [-1, 224, 138, 
12] 0\n", + " ZeroPad2d-111 [-1, 224, 138, 12] 0\n", + " Conv2d-112 [-1, 42, 138, 12] 9,408\n", + " BatchNorm2d-113 [-1, 42, 138, 12] 84\n", + " BasicConv2d-114 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-115 [-1, 224, 138, 12] 0\n", + " Conv2d-116 [-1, 64, 138, 12] 14,336\n", + " BatchNorm2d-117 [-1, 64, 138, 12] 128\n", + " BasicConv2d-118 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-119 [-1, 64, 140, 14] 0\n", + " Conv2d-120 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-121 [-1, 128, 138, 12] 256\n", + " BasicConv2d-122 [-1, 128, 138, 12] 0\n", + " ZeroPad2d-123 [-1, 224, 138, 12] 0\n", + " Conv2d-124 [-1, 42, 138, 12] 9,408\n", + " BatchNorm2d-125 [-1, 42, 138, 12] 84\n", + " BasicConv2d-126 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-127 [-1, 42, 140, 14] 0\n", + " Conv2d-128 [-1, 64, 138, 12] 24,192\n", + " BatchNorm2d-129 [-1, 64, 138, 12] 128\n", + " BasicConv2d-130 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-131 [-1, 64, 140, 14] 0\n", + " Conv2d-132 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-133 [-1, 128, 138, 12] 256\n", + " BasicConv2d-134 [-1, 128, 138, 12] 0\n", + " BatchNorm2d-135 [-1, 298, 138, 12] 596\n", + "Multi_2D_CNN_block-136 [-1, 298, 138, 12] 0\n", + " ZeroPad2d-137 [-1, 298, 138, 12] 0\n", + " Conv2d-138 [-1, 42, 138, 12] 12,516\n", + " BatchNorm2d-139 [-1, 42, 138, 12] 84\n", + " BasicConv2d-140 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-141 [-1, 298, 138, 12] 0\n", + " Conv2d-142 [-1, 64, 138, 12] 19,072\n", + " BatchNorm2d-143 [-1, 64, 138, 12] 128\n", + " BasicConv2d-144 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-145 [-1, 64, 140, 14] 0\n", + " Conv2d-146 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-147 [-1, 128, 138, 12] 256\n", + " BasicConv2d-148 [-1, 128, 138, 12] 0\n", + " ZeroPad2d-149 [-1, 298, 138, 12] 0\n", + " Conv2d-150 [-1, 42, 138, 12] 12,516\n", + " BatchNorm2d-151 [-1, 42, 138, 12] 84\n", + " BasicConv2d-152 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-153 [-1, 42, 140, 14] 0\n", + " Conv2d-154 [-1, 64, 138, 12] 24,192\n", + " BatchNorm2d-155 [-1, 64, 138, 12] 128\n", + " BasicConv2d-156 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-157 [-1, 64, 140, 14] 0\n", + " Conv2d-158 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-159 [-1, 128, 138, 12] 256\n", + " BasicConv2d-160 [-1, 128, 138, 12] 0\n", + " BatchNorm2d-161 [-1, 298, 138, 12] 596\n", + "Multi_2D_CNN_block-162 [-1, 298, 138, 12] 0\n", + " MaxPool2d-163 [-1, 298, 69, 12] 0\n", + " ZeroPad2d-164 [-1, 298, 69, 12] 0\n", + " Conv2d-165 [-1, 64, 69, 12] 19,072\n", + " BatchNorm2d-166 [-1, 64, 69, 12] 128\n", + " BasicConv2d-167 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-168 [-1, 298, 69, 12] 0\n", + " Conv2d-169 [-1, 96, 69, 12] 28,608\n", + " BatchNorm2d-170 [-1, 96, 69, 12] 192\n", + " BasicConv2d-171 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-172 [-1, 96, 71, 14] 0\n", + " Conv2d-173 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-174 [-1, 192, 69, 12] 384\n", + " BasicConv2d-175 [-1, 192, 69, 12] 0\n", + " ZeroPad2d-176 [-1, 298, 69, 12] 0\n", + " Conv2d-177 [-1, 64, 69, 12] 19,072\n", + " BatchNorm2d-178 [-1, 64, 69, 12] 128\n", + " BasicConv2d-179 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-180 [-1, 64, 71, 14] 0\n", + " Conv2d-181 [-1, 96, 69, 12] 55,296\n", + " BatchNorm2d-182 [-1, 96, 69, 12] 192\n", + " BasicConv2d-183 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-184 [-1, 96, 71, 14] 0\n", + " Conv2d-185 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-186 [-1, 192, 69, 12] 384\n", + " BasicConv2d-187 [-1, 192, 69, 12] 0\n", + " BatchNorm2d-188 [-1, 448, 69, 12] 896\n", + "Multi_2D_CNN_block-189 [-1, 448, 69, 12] 0\n", + " ZeroPad2d-190 [-1, 448, 69, 12] 0\n", + 
" Conv2d-191 [-1, 64, 69, 12] 28,672\n", + " BatchNorm2d-192 [-1, 64, 69, 12] 128\n", + " BasicConv2d-193 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-194 [-1, 448, 69, 12] 0\n", + " Conv2d-195 [-1, 96, 69, 12] 43,008\n", + " BatchNorm2d-196 [-1, 96, 69, 12] 192\n", + " BasicConv2d-197 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-198 [-1, 96, 71, 14] 0\n", + " Conv2d-199 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-200 [-1, 192, 69, 12] 384\n", + " BasicConv2d-201 [-1, 192, 69, 12] 0\n", + " ZeroPad2d-202 [-1, 448, 69, 12] 0\n", + " Conv2d-203 [-1, 64, 69, 12] 28,672\n", + " BatchNorm2d-204 [-1, 64, 69, 12] 128\n", + " BasicConv2d-205 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-206 [-1, 64, 71, 14] 0\n", + " Conv2d-207 [-1, 96, 69, 12] 55,296\n", + " BatchNorm2d-208 [-1, 96, 69, 12] 192\n", + " BasicConv2d-209 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-210 [-1, 96, 71, 14] 0\n", + " Conv2d-211 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-212 [-1, 192, 69, 12] 384\n", + " BasicConv2d-213 [-1, 192, 69, 12] 0\n", + " BatchNorm2d-214 [-1, 448, 69, 12] 896\n", + "Multi_2D_CNN_block-215 [-1, 448, 69, 12] 0\n", + " ZeroPad2d-216 [-1, 448, 69, 12] 0\n", + " Conv2d-217 [-1, 85, 69, 12] 38,080\n", + " BatchNorm2d-218 [-1, 85, 69, 12] 170\n", + " BasicConv2d-219 [-1, 85, 69, 12] 0\n", + " ZeroPad2d-220 [-1, 448, 69, 12] 0\n", + " Conv2d-221 [-1, 128, 69, 12] 57,344\n", + " BatchNorm2d-222 [-1, 128, 69, 12] 256\n", + " BasicConv2d-223 [-1, 128, 69, 12] 0\n", + " ZeroPad2d-224 [-1, 128, 71, 14] 0\n", + " Conv2d-225 [-1, 256, 69, 12] 294,912\n", + " BatchNorm2d-226 [-1, 256, 69, 12] 512\n", + " BasicConv2d-227 [-1, 256, 69, 12] 0\n", + " ZeroPad2d-228 [-1, 448, 69, 12] 0\n", + " Conv2d-229 [-1, 85, 69, 12] 38,080\n", + " BatchNorm2d-230 [-1, 85, 69, 12] 170\n", + " BasicConv2d-231 [-1, 85, 69, 12] 0\n", + " ZeroPad2d-232 [-1, 85, 71, 14] 0\n", + " Conv2d-233 [-1, 128, 69, 12] 97,920\n", + " BatchNorm2d-234 [-1, 128, 69, 12] 256\n", + " BasicConv2d-235 [-1, 128, 69, 12] 0\n", + " ZeroPad2d-236 [-1, 128, 71, 14] 0\n", + " Conv2d-237 [-1, 256, 69, 12] 294,912\n", + " BatchNorm2d-238 [-1, 256, 69, 12] 512\n", + " BasicConv2d-239 [-1, 256, 69, 12] 0\n", + " BatchNorm2d-240 [-1, 597, 69, 12] 1,194\n", + "Multi_2D_CNN_block-241 [-1, 597, 69, 12] 0\n", + " MaxPool2d-242 [-1, 597, 34, 12] 0\n", + " ZeroPad2d-243 [-1, 597, 34, 12] 0\n", + " Conv2d-244 [-1, 106, 34, 12] 63,282\n", + " BatchNorm2d-245 [-1, 106, 34, 12] 212\n", + " BasicConv2d-246 [-1, 106, 34, 12] 0\n", + " ZeroPad2d-247 [-1, 597, 34, 12] 0\n", + " Conv2d-248 [-1, 160, 34, 12] 95,520\n", + " BatchNorm2d-249 [-1, 160, 34, 12] 320\n", + " BasicConv2d-250 [-1, 160, 34, 12] 0\n", + " ZeroPad2d-251 [-1, 160, 36, 14] 0\n", + " Conv2d-252 [-1, 320, 34, 12] 460,800\n", + " BatchNorm2d-253 [-1, 320, 34, 12] 640\n", + " BasicConv2d-254 [-1, 320, 34, 12] 0\n", + " ZeroPad2d-255 [-1, 597, 34, 12] 0\n", + " Conv2d-256 [-1, 106, 34, 12] 63,282\n", + " BatchNorm2d-257 [-1, 106, 34, 12] 212\n", + " BasicConv2d-258 [-1, 106, 34, 12] 0\n", + " ZeroPad2d-259 [-1, 106, 36, 14] 0\n", + " Conv2d-260 [-1, 160, 34, 12] 152,640\n", + " BatchNorm2d-261 [-1, 160, 34, 12] 320\n", + " BasicConv2d-262 [-1, 160, 34, 12] 0\n", + " ZeroPad2d-263 [-1, 160, 36, 14] 0\n", + " Conv2d-264 [-1, 320, 34, 12] 460,800\n", + " BatchNorm2d-265 [-1, 320, 34, 12] 640\n", + " BasicConv2d-266 [-1, 320, 34, 12] 0\n", + " BatchNorm2d-267 [-1, 746, 34, 12] 1,492\n", + "Multi_2D_CNN_block-268 [-1, 746, 34, 12] 0\n", + " ZeroPad2d-269 [-1, 746, 34, 12] 0\n", + " Conv2d-270 [-1, 128, 34, 12] 95,488\n", + " BatchNorm2d-271 [-1, 
128, 34, 12] 256\n", + " BasicConv2d-272 [-1, 128, 34, 12] 0\n", + " ZeroPad2d-273 [-1, 746, 34, 12] 0\n", + " Conv2d-274 [-1, 192, 34, 12] 143,232\n", + " BatchNorm2d-275 [-1, 192, 34, 12] 384\n", + " BasicConv2d-276 [-1, 192, 34, 12] 0\n", + " ZeroPad2d-277 [-1, 192, 36, 14] 0\n", + " Conv2d-278 [-1, 384, 34, 12] 663,552\n", + " BatchNorm2d-279 [-1, 384, 34, 12] 768\n", + " BasicConv2d-280 [-1, 384, 34, 12] 0\n", + " ZeroPad2d-281 [-1, 746, 34, 12] 0\n", + " Conv2d-282 [-1, 128, 34, 12] 95,488\n", + " BatchNorm2d-283 [-1, 128, 34, 12] 256\n", + " BasicConv2d-284 [-1, 128, 34, 12] 0\n", + " ZeroPad2d-285 [-1, 128, 36, 14] 0\n", + " Conv2d-286 [-1, 192, 34, 12] 221,184\n", + " BatchNorm2d-287 [-1, 192, 34, 12] 384\n", + " BasicConv2d-288 [-1, 192, 34, 12] 0\n", + " ZeroPad2d-289 [-1, 192, 36, 14] 0\n", + " Conv2d-290 [-1, 384, 34, 12] 663,552\n", + " BatchNorm2d-291 [-1, 384, 34, 12] 768\n", + " BasicConv2d-292 [-1, 384, 34, 12] 0\n", + " BatchNorm2d-293 [-1, 896, 34, 12] 1,792\n", + "Multi_2D_CNN_block-294 [-1, 896, 34, 12] 0\n", + " ZeroPad2d-295 [-1, 896, 34, 12] 0\n", + " Conv2d-296 [-1, 149, 34, 12] 133,504\n", + " BatchNorm2d-297 [-1, 149, 34, 12] 298\n", + " BasicConv2d-298 [-1, 149, 34, 12] 0\n", + " ZeroPad2d-299 [-1, 896, 34, 12] 0\n", + " Conv2d-300 [-1, 224, 34, 12] 200,704\n", + " BatchNorm2d-301 [-1, 224, 34, 12] 448\n", + " BasicConv2d-302 [-1, 224, 34, 12] 0\n", + " ZeroPad2d-303 [-1, 224, 36, 14] 0\n", + " Conv2d-304 [-1, 448, 34, 12] 903,168\n", + " BatchNorm2d-305 [-1, 448, 34, 12] 896\n", + " BasicConv2d-306 [-1, 448, 34, 12] 0\n", + " ZeroPad2d-307 [-1, 896, 34, 12] 0\n", + " Conv2d-308 [-1, 149, 34, 12] 133,504\n", + " BatchNorm2d-309 [-1, 149, 34, 12] 298\n", + " BasicConv2d-310 [-1, 149, 34, 12] 0\n", + " ZeroPad2d-311 [-1, 149, 36, 14] 0\n", + " Conv2d-312 [-1, 224, 34, 12] 300,384\n", + " BatchNorm2d-313 [-1, 224, 34, 12] 448\n", + " BasicConv2d-314 [-1, 224, 34, 12] 0\n", + " ZeroPad2d-315 [-1, 224, 36, 14] 0\n", + " Conv2d-316 [-1, 448, 34, 12] 903,168\n", + " BatchNorm2d-317 [-1, 448, 34, 12] 896\n", + " BasicConv2d-318 [-1, 448, 34, 12] 0\n", + " BatchNorm2d-319 [-1, 1045, 34, 12] 2,090\n", + "Multi_2D_CNN_block-320 [-1, 1045, 34, 12] 0\n", + " MaxPool2d-321 [-1, 1045, 17, 12] 0\n", + " ZeroPad2d-322 [-1, 1045, 17, 12] 0\n", + " Conv2d-323 [-1, 170, 17, 12] 177,650\n", + " BatchNorm2d-324 [-1, 170, 17, 12] 340\n", + " BasicConv2d-325 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-326 [-1, 1045, 17, 12] 0\n", + " Conv2d-327 [-1, 256, 17, 12] 267,520\n", + " BatchNorm2d-328 [-1, 256, 17, 12] 512\n", + " BasicConv2d-329 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-330 [-1, 256, 19, 14] 0\n", + " Conv2d-331 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-332 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-333 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-334 [-1, 1045, 17, 12] 0\n", + " Conv2d-335 [-1, 170, 17, 12] 177,650\n", + " BatchNorm2d-336 [-1, 170, 17, 12] 340\n", + " BasicConv2d-337 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-338 [-1, 170, 19, 14] 0\n", + " Conv2d-339 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-340 [-1, 256, 17, 12] 512\n", + " BasicConv2d-341 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-342 [-1, 256, 19, 14] 0\n", + " Conv2d-343 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-344 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-345 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-346 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-347 [-1, 1194, 17, 12] 0\n", + " ZeroPad2d-348 [-1, 1194, 17, 12] 0\n", + " Conv2d-349 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-350 [-1, 170, 17, 12] 
340\n", + " BasicConv2d-351 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-352 [-1, 1194, 17, 12] 0\n", + " Conv2d-353 [-1, 256, 17, 12] 305,664\n", + " BatchNorm2d-354 [-1, 256, 17, 12] 512\n", + " BasicConv2d-355 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-356 [-1, 256, 19, 14] 0\n", + " Conv2d-357 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-358 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-359 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-360 [-1, 1194, 17, 12] 0\n", + " Conv2d-361 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-362 [-1, 170, 17, 12] 340\n", + " BasicConv2d-363 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-364 [-1, 170, 19, 14] 0\n", + " Conv2d-365 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-366 [-1, 256, 17, 12] 512\n", + " BasicConv2d-367 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-368 [-1, 256, 19, 14] 0\n", + " Conv2d-369 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-370 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-371 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-372 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-373 [-1, 1194, 17, 12] 0\n", + " ZeroPad2d-374 [-1, 1194, 17, 12] 0\n", + " Conv2d-375 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-376 [-1, 170, 17, 12] 340\n", + " BasicConv2d-377 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-378 [-1, 1194, 17, 12] 0\n", + " Conv2d-379 [-1, 256, 17, 12] 305,664\n", + " BatchNorm2d-380 [-1, 256, 17, 12] 512\n", + " BasicConv2d-381 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-382 [-1, 256, 19, 14] 0\n", + " Conv2d-383 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-384 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-385 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-386 [-1, 1194, 17, 12] 0\n", + " Conv2d-387 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-388 [-1, 170, 17, 12] 340\n", + " BasicConv2d-389 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-390 [-1, 170, 19, 14] 0\n", + " Conv2d-391 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-392 [-1, 256, 17, 12] 512\n", + " BasicConv2d-393 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-394 [-1, 256, 19, 14] 0\n", + " Conv2d-395 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-396 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-397 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-398 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-399 [-1, 1194, 17, 12] 0\n", + " MaxPool2d-400 [-1, 1194, 8, 12] 0\n", + " ZeroPad2d-401 [-1, 1194, 8, 12] 0\n", + " Conv2d-402 [-1, 256, 8, 12] 305,664\n", + " BatchNorm2d-403 [-1, 256, 8, 12] 512\n", + " BasicConv2d-404 [-1, 256, 8, 12] 0\n", + " ZeroPad2d-405 [-1, 1194, 8, 12] 0\n", + " Conv2d-406 [-1, 384, 8, 12] 458,496\n", + " BatchNorm2d-407 [-1, 384, 8, 12] 768\n", + " BasicConv2d-408 [-1, 384, 8, 12] 0\n", + " ZeroPad2d-409 [-1, 384, 10, 14] 0\n", + " Conv2d-410 [-1, 768, 8, 12] 2,654,208\n", + " BatchNorm2d-411 [-1, 768, 8, 12] 1,536\n", + " BasicConv2d-412 [-1, 768, 8, 12] 0\n", + " ZeroPad2d-413 [-1, 1194, 8, 12] 0\n", + " Conv2d-414 [-1, 256, 8, 12] 305,664\n", + " BatchNorm2d-415 [-1, 256, 8, 12] 512\n", + " BasicConv2d-416 [-1, 256, 8, 12] 0\n", + " ZeroPad2d-417 [-1, 256, 10, 14] 0\n", + " Conv2d-418 [-1, 384, 8, 12] 884,736\n", + " BatchNorm2d-419 [-1, 384, 8, 12] 768\n", + " BasicConv2d-420 [-1, 384, 8, 12] 0\n", + " ZeroPad2d-421 [-1, 384, 10, 14] 0\n", + " Conv2d-422 [-1, 768, 8, 12] 2,654,208\n", + " BatchNorm2d-423 [-1, 768, 8, 12] 1,536\n", + " BasicConv2d-424 [-1, 768, 8, 12] 0\n", + " BatchNorm2d-425 [-1, 1792, 8, 12] 3,584\n", + "Multi_2D_CNN_block-426 [-1, 1792, 8, 12] 0\n", + " ZeroPad2d-427 [-1, 1792, 8, 12] 0\n", + " Conv2d-428 [-1, 298, 8, 12] 534,016\n", + " BatchNorm2d-429 [-1, 298, 8, 12] 596\n", + " 
BasicConv2d-430 [-1, 298, 8, 12] 0\n", + " ZeroPad2d-431 [-1, 1792, 8, 12] 0\n", + " Conv2d-432 [-1, 448, 8, 12] 802,816\n", + " BatchNorm2d-433 [-1, 448, 8, 12] 896\n", + " BasicConv2d-434 [-1, 448, 8, 12] 0\n", + " ZeroPad2d-435 [-1, 448, 10, 14] 0\n", + " Conv2d-436 [-1, 896, 8, 12] 3,612,672\n", + " BatchNorm2d-437 [-1, 896, 8, 12] 1,792\n", + " BasicConv2d-438 [-1, 896, 8, 12] 0\n", + " ZeroPad2d-439 [-1, 1792, 8, 12] 0\n", + " Conv2d-440 [-1, 298, 8, 12] 534,016\n", + " BatchNorm2d-441 [-1, 298, 8, 12] 596\n", + " BasicConv2d-442 [-1, 298, 8, 12] 0\n", + " ZeroPad2d-443 [-1, 298, 10, 14] 0\n", + " Conv2d-444 [-1, 448, 8, 12] 1,201,536\n", + " BatchNorm2d-445 [-1, 448, 8, 12] 896\n", + " BasicConv2d-446 [-1, 448, 8, 12] 0\n", + " ZeroPad2d-447 [-1, 448, 10, 14] 0\n", + " Conv2d-448 [-1, 896, 8, 12] 3,612,672\n", + " BatchNorm2d-449 [-1, 896, 8, 12] 1,792\n", + " BasicConv2d-450 [-1, 896, 8, 12] 0\n", + " BatchNorm2d-451 [-1, 2090, 8, 12] 4,180\n", + "Multi_2D_CNN_block-452 [-1, 2090, 8, 12] 0\n", + " ZeroPad2d-453 [-1, 2090, 8, 12] 0\n", + " Conv2d-454 [-1, 341, 8, 12] 712,690\n", + " BatchNorm2d-455 [-1, 341, 8, 12] 682\n", + " BasicConv2d-456 [-1, 341, 8, 12] 0\n", + " ZeroPad2d-457 [-1, 2090, 8, 12] 0\n", + " Conv2d-458 [-1, 512, 8, 12] 1,070,080\n", + " BatchNorm2d-459 [-1, 512, 8, 12] 1,024\n", + " BasicConv2d-460 [-1, 512, 8, 12] 0\n", + " ZeroPad2d-461 [-1, 512, 10, 14] 0\n", + " Conv2d-462 [-1, 1024, 8, 12] 4,718,592\n", + " BatchNorm2d-463 [-1, 1024, 8, 12] 2,048\n", + " BasicConv2d-464 [-1, 1024, 8, 12] 0\n", + " ZeroPad2d-465 [-1, 2090, 8, 12] 0\n", + " Conv2d-466 [-1, 341, 8, 12] 712,690\n", + " BatchNorm2d-467 [-1, 341, 8, 12] 682\n", + " BasicConv2d-468 [-1, 341, 8, 12] 0\n", + " ZeroPad2d-469 [-1, 341, 10, 14] 0\n", + " Conv2d-470 [-1, 512, 8, 12] 1,571,328\n", + " BatchNorm2d-471 [-1, 512, 8, 12] 1,024\n", + " BasicConv2d-472 [-1, 512, 8, 12] 0\n", + " ZeroPad2d-473 [-1, 512, 10, 14] 0\n", + " Conv2d-474 [-1, 1024, 8, 12] 4,718,592\n", + " BatchNorm2d-475 [-1, 1024, 8, 12] 2,048\n", + " BasicConv2d-476 [-1, 1024, 8, 12] 0\n", + " BatchNorm2d-477 [-1, 2389, 8, 12] 4,778\n", + "Multi_2D_CNN_block-478 [-1, 2389, 8, 12] 0\n", + "AdaptiveAvgPool2d-479 [-1, 2389, 1, 1] 0\n", + " Flatten-480 [-1, 2389] 0\n", + " Dropout-481 [-1, 2389] 0\n", + " Linear-482 [-1, 1] 2,390\n", + "================================================================\n", + "Total params: 49,719,798\n", + "Trainable params: 49,719,798\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.11\n", + "Forward/backward pass size (MB): 884.62\n", + "Params size (MB): 189.67\n", + "Estimated Total Size (MB): 1074.40\n", + "----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(MyModel(), input_size=(1, 2500, 12))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " ZeroPad2d-1 [-1, 1, 2514, 12] 0\n", + " Conv2d-2 [-1, 64, 1250, 12] 960\n", + " BatchNorm2d-3 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-4 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-5 [-1, 64, 1250, 12] 0\n", + " Conv2d-6 [-1, 21, 1250, 12] 1,344\n", + " BatchNorm2d-7 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-8 [-1, 21, 1250, 12] 0\n", + " 
ZeroPad2d-9 [-1, 64, 1250, 12] 0\n", + " Conv2d-10 [-1, 32, 1250, 12] 2,048\n", + " BatchNorm2d-11 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-12 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-13 [-1, 32, 1252, 14] 0\n", + " Conv2d-14 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-15 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-16 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-17 [-1, 64, 1250, 12] 0\n", + " Conv2d-18 [-1, 21, 1250, 12] 1,344\n", + " BatchNorm2d-19 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-20 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-21 [-1, 21, 1252, 14] 0\n", + " Conv2d-22 [-1, 32, 1250, 12] 6,048\n", + " BatchNorm2d-23 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-24 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-25 [-1, 32, 1252, 14] 0\n", + " Conv2d-26 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-27 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-28 [-1, 64, 1250, 12] 0\n", + " BatchNorm2d-29 [-1, 149, 1250, 12] 298\n", + "Multi_2D_CNN_block-30 [-1, 149, 1250, 12] 0\n", + " ZeroPad2d-31 [-1, 149, 1250, 12] 0\n", + " Conv2d-32 [-1, 21, 1250, 12] 3,129\n", + " BatchNorm2d-33 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-34 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-35 [-1, 149, 1250, 12] 0\n", + " Conv2d-36 [-1, 32, 1250, 12] 4,768\n", + " BatchNorm2d-37 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-38 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-39 [-1, 32, 1252, 14] 0\n", + " Conv2d-40 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-41 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-42 [-1, 64, 1250, 12] 0\n", + " ZeroPad2d-43 [-1, 149, 1250, 12] 0\n", + " Conv2d-44 [-1, 21, 1250, 12] 3,129\n", + " BatchNorm2d-45 [-1, 21, 1250, 12] 42\n", + " BasicConv2d-46 [-1, 21, 1250, 12] 0\n", + " ZeroPad2d-47 [-1, 21, 1252, 14] 0\n", + " Conv2d-48 [-1, 32, 1250, 12] 6,048\n", + " BatchNorm2d-49 [-1, 32, 1250, 12] 64\n", + " BasicConv2d-50 [-1, 32, 1250, 12] 0\n", + " ZeroPad2d-51 [-1, 32, 1252, 14] 0\n", + " Conv2d-52 [-1, 64, 1250, 12] 18,432\n", + " BatchNorm2d-53 [-1, 64, 1250, 12] 128\n", + " BasicConv2d-54 [-1, 64, 1250, 12] 0\n", + " BatchNorm2d-55 [-1, 149, 1250, 12] 298\n", + "Multi_2D_CNN_block-56 [-1, 149, 1250, 12] 0\n", + " MaxPool2d-57 [-1, 149, 416, 12] 0\n", + " ZeroPad2d-58 [-1, 149, 416, 12] 0\n", + " Conv2d-59 [-1, 32, 416, 12] 4,768\n", + " BatchNorm2d-60 [-1, 32, 416, 12] 64\n", + " BasicConv2d-61 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-62 [-1, 149, 416, 12] 0\n", + " Conv2d-63 [-1, 48, 416, 12] 7,152\n", + " BatchNorm2d-64 [-1, 48, 416, 12] 96\n", + " BasicConv2d-65 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-66 [-1, 48, 418, 14] 0\n", + " Conv2d-67 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-68 [-1, 96, 416, 12] 192\n", + " BasicConv2d-69 [-1, 96, 416, 12] 0\n", + " ZeroPad2d-70 [-1, 149, 416, 12] 0\n", + " Conv2d-71 [-1, 32, 416, 12] 4,768\n", + " BatchNorm2d-72 [-1, 32, 416, 12] 64\n", + " BasicConv2d-73 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-74 [-1, 32, 418, 14] 0\n", + " Conv2d-75 [-1, 48, 416, 12] 13,824\n", + " BatchNorm2d-76 [-1, 48, 416, 12] 96\n", + " BasicConv2d-77 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-78 [-1, 48, 418, 14] 0\n", + " Conv2d-79 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-80 [-1, 96, 416, 12] 192\n", + " BasicConv2d-81 [-1, 96, 416, 12] 0\n", + " BatchNorm2d-82 [-1, 224, 416, 12] 448\n", + "Multi_2D_CNN_block-83 [-1, 224, 416, 12] 0\n", + " ZeroPad2d-84 [-1, 224, 416, 12] 0\n", + " Conv2d-85 [-1, 32, 416, 12] 7,168\n", + " BatchNorm2d-86 [-1, 32, 416, 12] 64\n", + " BasicConv2d-87 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-88 [-1, 224, 416, 12] 0\n", + " Conv2d-89 [-1, 48, 416, 12] 10,752\n", + " 
BatchNorm2d-90 [-1, 48, 416, 12] 96\n", + " BasicConv2d-91 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-92 [-1, 48, 418, 14] 0\n", + " Conv2d-93 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-94 [-1, 96, 416, 12] 192\n", + " BasicConv2d-95 [-1, 96, 416, 12] 0\n", + " ZeroPad2d-96 [-1, 224, 416, 12] 0\n", + " Conv2d-97 [-1, 32, 416, 12] 7,168\n", + " BatchNorm2d-98 [-1, 32, 416, 12] 64\n", + " BasicConv2d-99 [-1, 32, 416, 12] 0\n", + " ZeroPad2d-100 [-1, 32, 418, 14] 0\n", + " Conv2d-101 [-1, 48, 416, 12] 13,824\n", + " BatchNorm2d-102 [-1, 48, 416, 12] 96\n", + " BasicConv2d-103 [-1, 48, 416, 12] 0\n", + " ZeroPad2d-104 [-1, 48, 418, 14] 0\n", + " Conv2d-105 [-1, 96, 416, 12] 41,472\n", + " BatchNorm2d-106 [-1, 96, 416, 12] 192\n", + " BasicConv2d-107 [-1, 96, 416, 12] 0\n", + " BatchNorm2d-108 [-1, 224, 416, 12] 448\n", + "Multi_2D_CNN_block-109 [-1, 224, 416, 12] 0\n", + " MaxPool2d-110 [-1, 224, 138, 12] 0\n", + " ZeroPad2d-111 [-1, 224, 138, 12] 0\n", + " Conv2d-112 [-1, 42, 138, 12] 9,408\n", + " BatchNorm2d-113 [-1, 42, 138, 12] 84\n", + " BasicConv2d-114 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-115 [-1, 224, 138, 12] 0\n", + " Conv2d-116 [-1, 64, 138, 12] 14,336\n", + " BatchNorm2d-117 [-1, 64, 138, 12] 128\n", + " BasicConv2d-118 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-119 [-1, 64, 140, 14] 0\n", + " Conv2d-120 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-121 [-1, 128, 138, 12] 256\n", + " BasicConv2d-122 [-1, 128, 138, 12] 0\n", + " ZeroPad2d-123 [-1, 224, 138, 12] 0\n", + " Conv2d-124 [-1, 42, 138, 12] 9,408\n", + " BatchNorm2d-125 [-1, 42, 138, 12] 84\n", + " BasicConv2d-126 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-127 [-1, 42, 140, 14] 0\n", + " Conv2d-128 [-1, 64, 138, 12] 24,192\n", + " BatchNorm2d-129 [-1, 64, 138, 12] 128\n", + " BasicConv2d-130 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-131 [-1, 64, 140, 14] 0\n", + " Conv2d-132 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-133 [-1, 128, 138, 12] 256\n", + " BasicConv2d-134 [-1, 128, 138, 12] 0\n", + " BatchNorm2d-135 [-1, 298, 138, 12] 596\n", + "Multi_2D_CNN_block-136 [-1, 298, 138, 12] 0\n", + " ZeroPad2d-137 [-1, 298, 138, 12] 0\n", + " Conv2d-138 [-1, 42, 138, 12] 12,516\n", + " BatchNorm2d-139 [-1, 42, 138, 12] 84\n", + " BasicConv2d-140 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-141 [-1, 298, 138, 12] 0\n", + " Conv2d-142 [-1, 64, 138, 12] 19,072\n", + " BatchNorm2d-143 [-1, 64, 138, 12] 128\n", + " BasicConv2d-144 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-145 [-1, 64, 140, 14] 0\n", + " Conv2d-146 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-147 [-1, 128, 138, 12] 256\n", + " BasicConv2d-148 [-1, 128, 138, 12] 0\n", + " ZeroPad2d-149 [-1, 298, 138, 12] 0\n", + " Conv2d-150 [-1, 42, 138, 12] 12,516\n", + " BatchNorm2d-151 [-1, 42, 138, 12] 84\n", + " BasicConv2d-152 [-1, 42, 138, 12] 0\n", + " ZeroPad2d-153 [-1, 42, 140, 14] 0\n", + " Conv2d-154 [-1, 64, 138, 12] 24,192\n", + " BatchNorm2d-155 [-1, 64, 138, 12] 128\n", + " BasicConv2d-156 [-1, 64, 138, 12] 0\n", + " ZeroPad2d-157 [-1, 64, 140, 14] 0\n", + " Conv2d-158 [-1, 128, 138, 12] 73,728\n", + " BatchNorm2d-159 [-1, 128, 138, 12] 256\n", + " BasicConv2d-160 [-1, 128, 138, 12] 0\n", + " BatchNorm2d-161 [-1, 298, 138, 12] 596\n", + "Multi_2D_CNN_block-162 [-1, 298, 138, 12] 0\n", + " MaxPool2d-163 [-1, 298, 69, 12] 0\n", + " ZeroPad2d-164 [-1, 298, 69, 12] 0\n", + " Conv2d-165 [-1, 64, 69, 12] 19,072\n", + " BatchNorm2d-166 [-1, 64, 69, 12] 128\n", + " BasicConv2d-167 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-168 [-1, 298, 69, 12] 0\n", + " Conv2d-169 [-1, 96, 69, 12] 28,608\n", + " 
BatchNorm2d-170 [-1, 96, 69, 12] 192\n", + " BasicConv2d-171 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-172 [-1, 96, 71, 14] 0\n", + " Conv2d-173 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-174 [-1, 192, 69, 12] 384\n", + " BasicConv2d-175 [-1, 192, 69, 12] 0\n", + " ZeroPad2d-176 [-1, 298, 69, 12] 0\n", + " Conv2d-177 [-1, 64, 69, 12] 19,072\n", + " BatchNorm2d-178 [-1, 64, 69, 12] 128\n", + " BasicConv2d-179 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-180 [-1, 64, 71, 14] 0\n", + " Conv2d-181 [-1, 96, 69, 12] 55,296\n", + " BatchNorm2d-182 [-1, 96, 69, 12] 192\n", + " BasicConv2d-183 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-184 [-1, 96, 71, 14] 0\n", + " Conv2d-185 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-186 [-1, 192, 69, 12] 384\n", + " BasicConv2d-187 [-1, 192, 69, 12] 0\n", + " BatchNorm2d-188 [-1, 448, 69, 12] 896\n", + "Multi_2D_CNN_block-189 [-1, 448, 69, 12] 0\n", + " ZeroPad2d-190 [-1, 448, 69, 12] 0\n", + " Conv2d-191 [-1, 64, 69, 12] 28,672\n", + " BatchNorm2d-192 [-1, 64, 69, 12] 128\n", + " BasicConv2d-193 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-194 [-1, 448, 69, 12] 0\n", + " Conv2d-195 [-1, 96, 69, 12] 43,008\n", + " BatchNorm2d-196 [-1, 96, 69, 12] 192\n", + " BasicConv2d-197 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-198 [-1, 96, 71, 14] 0\n", + " Conv2d-199 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-200 [-1, 192, 69, 12] 384\n", + " BasicConv2d-201 [-1, 192, 69, 12] 0\n", + " ZeroPad2d-202 [-1, 448, 69, 12] 0\n", + " Conv2d-203 [-1, 64, 69, 12] 28,672\n", + " BatchNorm2d-204 [-1, 64, 69, 12] 128\n", + " BasicConv2d-205 [-1, 64, 69, 12] 0\n", + " ZeroPad2d-206 [-1, 64, 71, 14] 0\n", + " Conv2d-207 [-1, 96, 69, 12] 55,296\n", + " BatchNorm2d-208 [-1, 96, 69, 12] 192\n", + " BasicConv2d-209 [-1, 96, 69, 12] 0\n", + " ZeroPad2d-210 [-1, 96, 71, 14] 0\n", + " Conv2d-211 [-1, 192, 69, 12] 165,888\n", + " BatchNorm2d-212 [-1, 192, 69, 12] 384\n", + " BasicConv2d-213 [-1, 192, 69, 12] 0\n", + " BatchNorm2d-214 [-1, 448, 69, 12] 896\n", + "Multi_2D_CNN_block-215 [-1, 448, 69, 12] 0\n", + " ZeroPad2d-216 [-1, 448, 69, 12] 0\n", + " Conv2d-217 [-1, 85, 69, 12] 38,080\n", + " BatchNorm2d-218 [-1, 85, 69, 12] 170\n", + " BasicConv2d-219 [-1, 85, 69, 12] 0\n", + " ZeroPad2d-220 [-1, 448, 69, 12] 0\n", + " Conv2d-221 [-1, 128, 69, 12] 57,344\n", + " BatchNorm2d-222 [-1, 128, 69, 12] 256\n", + " BasicConv2d-223 [-1, 128, 69, 12] 0\n", + " ZeroPad2d-224 [-1, 128, 71, 14] 0\n", + " Conv2d-225 [-1, 256, 69, 12] 294,912\n", + " BatchNorm2d-226 [-1, 256, 69, 12] 512\n", + " BasicConv2d-227 [-1, 256, 69, 12] 0\n", + " ZeroPad2d-228 [-1, 448, 69, 12] 0\n", + " Conv2d-229 [-1, 85, 69, 12] 38,080\n", + " BatchNorm2d-230 [-1, 85, 69, 12] 170\n", + " BasicConv2d-231 [-1, 85, 69, 12] 0\n", + " ZeroPad2d-232 [-1, 85, 71, 14] 0\n", + " Conv2d-233 [-1, 128, 69, 12] 97,920\n", + " BatchNorm2d-234 [-1, 128, 69, 12] 256\n", + " BasicConv2d-235 [-1, 128, 69, 12] 0\n", + " ZeroPad2d-236 [-1, 128, 71, 14] 0\n", + " Conv2d-237 [-1, 256, 69, 12] 294,912\n", + " BatchNorm2d-238 [-1, 256, 69, 12] 512\n", + " BasicConv2d-239 [-1, 256, 69, 12] 0\n", + " BatchNorm2d-240 [-1, 597, 69, 12] 1,194\n", + "Multi_2D_CNN_block-241 [-1, 597, 69, 12] 0\n", + " MaxPool2d-242 [-1, 597, 34, 12] 0\n", + " ZeroPad2d-243 [-1, 597, 34, 12] 0\n", + " Conv2d-244 [-1, 106, 34, 12] 63,282\n", + " BatchNorm2d-245 [-1, 106, 34, 12] 212\n", + " BasicConv2d-246 [-1, 106, 34, 12] 0\n", + " ZeroPad2d-247 [-1, 597, 34, 12] 0\n", + " Conv2d-248 [-1, 160, 34, 12] 95,520\n", + " BatchNorm2d-249 [-1, 160, 34, 12] 320\n", + " BasicConv2d-250 [-1, 160, 34, 12] 
0\n", + " ZeroPad2d-251 [-1, 160, 36, 14] 0\n", + " Conv2d-252 [-1, 320, 34, 12] 460,800\n", + " BatchNorm2d-253 [-1, 320, 34, 12] 640\n", + " BasicConv2d-254 [-1, 320, 34, 12] 0\n", + " ZeroPad2d-255 [-1, 597, 34, 12] 0\n", + " Conv2d-256 [-1, 106, 34, 12] 63,282\n", + " BatchNorm2d-257 [-1, 106, 34, 12] 212\n", + " BasicConv2d-258 [-1, 106, 34, 12] 0\n", + " ZeroPad2d-259 [-1, 106, 36, 14] 0\n", + " Conv2d-260 [-1, 160, 34, 12] 152,640\n", + " BatchNorm2d-261 [-1, 160, 34, 12] 320\n", + " BasicConv2d-262 [-1, 160, 34, 12] 0\n", + " ZeroPad2d-263 [-1, 160, 36, 14] 0\n", + " Conv2d-264 [-1, 320, 34, 12] 460,800\n", + " BatchNorm2d-265 [-1, 320, 34, 12] 640\n", + " BasicConv2d-266 [-1, 320, 34, 12] 0\n", + " BatchNorm2d-267 [-1, 746, 34, 12] 1,492\n", + "Multi_2D_CNN_block-268 [-1, 746, 34, 12] 0\n", + " ZeroPad2d-269 [-1, 746, 34, 12] 0\n", + " Conv2d-270 [-1, 128, 34, 12] 95,488\n", + " BatchNorm2d-271 [-1, 128, 34, 12] 256\n", + " BasicConv2d-272 [-1, 128, 34, 12] 0\n", + " ZeroPad2d-273 [-1, 746, 34, 12] 0\n", + " Conv2d-274 [-1, 192, 34, 12] 143,232\n", + " BatchNorm2d-275 [-1, 192, 34, 12] 384\n", + " BasicConv2d-276 [-1, 192, 34, 12] 0\n", + " ZeroPad2d-277 [-1, 192, 36, 14] 0\n", + " Conv2d-278 [-1, 384, 34, 12] 663,552\n", + " BatchNorm2d-279 [-1, 384, 34, 12] 768\n", + " BasicConv2d-280 [-1, 384, 34, 12] 0\n", + " ZeroPad2d-281 [-1, 746, 34, 12] 0\n", + " Conv2d-282 [-1, 128, 34, 12] 95,488\n", + " BatchNorm2d-283 [-1, 128, 34, 12] 256\n", + " BasicConv2d-284 [-1, 128, 34, 12] 0\n", + " ZeroPad2d-285 [-1, 128, 36, 14] 0\n", + " Conv2d-286 [-1, 192, 34, 12] 221,184\n", + " BatchNorm2d-287 [-1, 192, 34, 12] 384\n", + " BasicConv2d-288 [-1, 192, 34, 12] 0\n", + " ZeroPad2d-289 [-1, 192, 36, 14] 0\n", + " Conv2d-290 [-1, 384, 34, 12] 663,552\n", + " BatchNorm2d-291 [-1, 384, 34, 12] 768\n", + " BasicConv2d-292 [-1, 384, 34, 12] 0\n", + " BatchNorm2d-293 [-1, 896, 34, 12] 1,792\n", + "Multi_2D_CNN_block-294 [-1, 896, 34, 12] 0\n", + " ZeroPad2d-295 [-1, 896, 34, 12] 0\n", + " Conv2d-296 [-1, 149, 34, 12] 133,504\n", + " BatchNorm2d-297 [-1, 149, 34, 12] 298\n", + " BasicConv2d-298 [-1, 149, 34, 12] 0\n", + " ZeroPad2d-299 [-1, 896, 34, 12] 0\n", + " Conv2d-300 [-1, 224, 34, 12] 200,704\n", + " BatchNorm2d-301 [-1, 224, 34, 12] 448\n", + " BasicConv2d-302 [-1, 224, 34, 12] 0\n", + " ZeroPad2d-303 [-1, 224, 36, 14] 0\n", + " Conv2d-304 [-1, 448, 34, 12] 903,168\n", + " BatchNorm2d-305 [-1, 448, 34, 12] 896\n", + " BasicConv2d-306 [-1, 448, 34, 12] 0\n", + " ZeroPad2d-307 [-1, 896, 34, 12] 0\n", + " Conv2d-308 [-1, 149, 34, 12] 133,504\n", + " BatchNorm2d-309 [-1, 149, 34, 12] 298\n", + " BasicConv2d-310 [-1, 149, 34, 12] 0\n", + " ZeroPad2d-311 [-1, 149, 36, 14] 0\n", + " Conv2d-312 [-1, 224, 34, 12] 300,384\n", + " BatchNorm2d-313 [-1, 224, 34, 12] 448\n", + " BasicConv2d-314 [-1, 224, 34, 12] 0\n", + " ZeroPad2d-315 [-1, 224, 36, 14] 0\n", + " Conv2d-316 [-1, 448, 34, 12] 903,168\n", + " BatchNorm2d-317 [-1, 448, 34, 12] 896\n", + " BasicConv2d-318 [-1, 448, 34, 12] 0\n", + " BatchNorm2d-319 [-1, 1045, 34, 12] 2,090\n", + "Multi_2D_CNN_block-320 [-1, 1045, 34, 12] 0\n", + " MaxPool2d-321 [-1, 1045, 17, 12] 0\n", + " ZeroPad2d-322 [-1, 1045, 17, 12] 0\n", + " Conv2d-323 [-1, 170, 17, 12] 177,650\n", + " BatchNorm2d-324 [-1, 170, 17, 12] 340\n", + " BasicConv2d-325 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-326 [-1, 1045, 17, 12] 0\n", + " Conv2d-327 [-1, 256, 17, 12] 267,520\n", + " BatchNorm2d-328 [-1, 256, 17, 12] 512\n", + " BasicConv2d-329 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-330 [-1, 
256, 19, 14] 0\n", + " Conv2d-331 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-332 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-333 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-334 [-1, 1045, 17, 12] 0\n", + " Conv2d-335 [-1, 170, 17, 12] 177,650\n", + " BatchNorm2d-336 [-1, 170, 17, 12] 340\n", + " BasicConv2d-337 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-338 [-1, 170, 19, 14] 0\n", + " Conv2d-339 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-340 [-1, 256, 17, 12] 512\n", + " BasicConv2d-341 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-342 [-1, 256, 19, 14] 0\n", + " Conv2d-343 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-344 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-345 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-346 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-347 [-1, 1194, 17, 12] 0\n", + " ZeroPad2d-348 [-1, 1194, 17, 12] 0\n", + " Conv2d-349 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-350 [-1, 170, 17, 12] 340\n", + " BasicConv2d-351 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-352 [-1, 1194, 17, 12] 0\n", + " Conv2d-353 [-1, 256, 17, 12] 305,664\n", + " BatchNorm2d-354 [-1, 256, 17, 12] 512\n", + " BasicConv2d-355 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-356 [-1, 256, 19, 14] 0\n", + " Conv2d-357 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-358 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-359 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-360 [-1, 1194, 17, 12] 0\n", + " Conv2d-361 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-362 [-1, 170, 17, 12] 340\n", + " BasicConv2d-363 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-364 [-1, 170, 19, 14] 0\n", + " Conv2d-365 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-366 [-1, 256, 17, 12] 512\n", + " BasicConv2d-367 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-368 [-1, 256, 19, 14] 0\n", + " Conv2d-369 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-370 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-371 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-372 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-373 [-1, 1194, 17, 12] 0\n", + " ZeroPad2d-374 [-1, 1194, 17, 12] 0\n", + " Conv2d-375 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-376 [-1, 170, 17, 12] 340\n", + " BasicConv2d-377 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-378 [-1, 1194, 17, 12] 0\n", + " Conv2d-379 [-1, 256, 17, 12] 305,664\n", + " BatchNorm2d-380 [-1, 256, 17, 12] 512\n", + " BasicConv2d-381 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-382 [-1, 256, 19, 14] 0\n", + " Conv2d-383 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-384 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-385 [-1, 512, 17, 12] 0\n", + " ZeroPad2d-386 [-1, 1194, 17, 12] 0\n", + " Conv2d-387 [-1, 170, 17, 12] 202,980\n", + " BatchNorm2d-388 [-1, 170, 17, 12] 340\n", + " BasicConv2d-389 [-1, 170, 17, 12] 0\n", + " ZeroPad2d-390 [-1, 170, 19, 14] 0\n", + " Conv2d-391 [-1, 256, 17, 12] 391,680\n", + " BatchNorm2d-392 [-1, 256, 17, 12] 512\n", + " BasicConv2d-393 [-1, 256, 17, 12] 0\n", + " ZeroPad2d-394 [-1, 256, 19, 14] 0\n", + " Conv2d-395 [-1, 512, 17, 12] 1,179,648\n", + " BatchNorm2d-396 [-1, 512, 17, 12] 1,024\n", + " BasicConv2d-397 [-1, 512, 17, 12] 0\n", + " BatchNorm2d-398 [-1, 1194, 17, 12] 2,388\n", + "Multi_2D_CNN_block-399 [-1, 1194, 17, 12] 0\n", + " MaxPool2d-400 [-1, 1194, 8, 12] 0\n", + " ZeroPad2d-401 [-1, 1194, 8, 12] 0\n", + " Conv2d-402 [-1, 256, 8, 12] 305,664\n", + " BatchNorm2d-403 [-1, 256, 8, 12] 512\n", + " BasicConv2d-404 [-1, 256, 8, 12] 0\n", + " ZeroPad2d-405 [-1, 1194, 8, 12] 0\n", + " Conv2d-406 [-1, 384, 8, 12] 458,496\n", + " BatchNorm2d-407 [-1, 384, 8, 12] 768\n", + " BasicConv2d-408 [-1, 384, 8, 12] 0\n", + " ZeroPad2d-409 [-1, 
384, 10, 14] 0\n", + " Conv2d-410 [-1, 768, 8, 12] 2,654,208\n", + " BatchNorm2d-411 [-1, 768, 8, 12] 1,536\n", + " BasicConv2d-412 [-1, 768, 8, 12] 0\n", + " ZeroPad2d-413 [-1, 1194, 8, 12] 0\n", + " Conv2d-414 [-1, 256, 8, 12] 305,664\n", + " BatchNorm2d-415 [-1, 256, 8, 12] 512\n", + " BasicConv2d-416 [-1, 256, 8, 12] 0\n", + " ZeroPad2d-417 [-1, 256, 10, 14] 0\n", + " Conv2d-418 [-1, 384, 8, 12] 884,736\n", + " BatchNorm2d-419 [-1, 384, 8, 12] 768\n", + " BasicConv2d-420 [-1, 384, 8, 12] 0\n", + " ZeroPad2d-421 [-1, 384, 10, 14] 0\n", + " Conv2d-422 [-1, 768, 8, 12] 2,654,208\n", + " BatchNorm2d-423 [-1, 768, 8, 12] 1,536\n", + " BasicConv2d-424 [-1, 768, 8, 12] 0\n", + " BatchNorm2d-425 [-1, 1792, 8, 12] 3,584\n", + "Multi_2D_CNN_block-426 [-1, 1792, 8, 12] 0\n", + " ZeroPad2d-427 [-1, 1792, 8, 12] 0\n", + " Conv2d-428 [-1, 298, 8, 12] 534,016\n", + " BatchNorm2d-429 [-1, 298, 8, 12] 596\n", + " BasicConv2d-430 [-1, 298, 8, 12] 0\n", + " ZeroPad2d-431 [-1, 1792, 8, 12] 0\n", + " Conv2d-432 [-1, 448, 8, 12] 802,816\n", + " BatchNorm2d-433 [-1, 448, 8, 12] 896\n", + " BasicConv2d-434 [-1, 448, 8, 12] 0\n", + " ZeroPad2d-435 [-1, 448, 10, 14] 0\n", + " Conv2d-436 [-1, 896, 8, 12] 3,612,672\n", + " BatchNorm2d-437 [-1, 896, 8, 12] 1,792\n", + " BasicConv2d-438 [-1, 896, 8, 12] 0\n", + " ZeroPad2d-439 [-1, 1792, 8, 12] 0\n", + " Conv2d-440 [-1, 298, 8, 12] 534,016\n", + " BatchNorm2d-441 [-1, 298, 8, 12] 596\n", + " BasicConv2d-442 [-1, 298, 8, 12] 0\n", + " ZeroPad2d-443 [-1, 298, 10, 14] 0\n", + " Conv2d-444 [-1, 448, 8, 12] 1,201,536\n", + " BatchNorm2d-445 [-1, 448, 8, 12] 896\n", + " BasicConv2d-446 [-1, 448, 8, 12] 0\n", + " ZeroPad2d-447 [-1, 448, 10, 14] 0\n", + " Conv2d-448 [-1, 896, 8, 12] 3,612,672\n", + " BatchNorm2d-449 [-1, 896, 8, 12] 1,792\n", + " BasicConv2d-450 [-1, 896, 8, 12] 0\n", + " BatchNorm2d-451 [-1, 2090, 8, 12] 4,180\n", + "Multi_2D_CNN_block-452 [-1, 2090, 8, 12] 0\n", + " ZeroPad2d-453 [-1, 2090, 8, 12] 0\n", + " Conv2d-454 [-1, 341, 8, 12] 712,690\n", + " BatchNorm2d-455 [-1, 341, 8, 12] 682\n", + " BasicConv2d-456 [-1, 341, 8, 12] 0\n", + " ZeroPad2d-457 [-1, 2090, 8, 12] 0\n", + " Conv2d-458 [-1, 512, 8, 12] 1,070,080\n", + " BatchNorm2d-459 [-1, 512, 8, 12] 1,024\n", + " BasicConv2d-460 [-1, 512, 8, 12] 0\n", + " ZeroPad2d-461 [-1, 512, 10, 14] 0\n", + " Conv2d-462 [-1, 1024, 8, 12] 4,718,592\n", + " BatchNorm2d-463 [-1, 1024, 8, 12] 2,048\n", + " BasicConv2d-464 [-1, 1024, 8, 12] 0\n", + " ZeroPad2d-465 [-1, 2090, 8, 12] 0\n", + " Conv2d-466 [-1, 341, 8, 12] 712,690\n", + " BatchNorm2d-467 [-1, 341, 8, 12] 682\n", + " BasicConv2d-468 [-1, 341, 8, 12] 0\n", + " ZeroPad2d-469 [-1, 341, 10, 14] 0\n", + " Conv2d-470 [-1, 512, 8, 12] 1,571,328\n", + " BatchNorm2d-471 [-1, 512, 8, 12] 1,024\n", + " BasicConv2d-472 [-1, 512, 8, 12] 0\n", + " ZeroPad2d-473 [-1, 512, 10, 14] 0\n", + " Conv2d-474 [-1, 1024, 8, 12] 4,718,592\n", + " BatchNorm2d-475 [-1, 1024, 8, 12] 2,048\n", + " BasicConv2d-476 [-1, 1024, 8, 12] 0\n", + " BatchNorm2d-477 [-1, 2389, 8, 12] 4,778\n", + "Multi_2D_CNN_block-478 [-1, 2389, 8, 12] 0\n", + "AdaptiveAvgPool2d-479 [-1, 2389, 1, 1] 0\n", + " Flatten-480 [-1, 2389] 0\n", + " Dropout-481 [-1, 2389] 0\n", + " Linear-482 [-1, 1] 2,390\n", + "================================================================\n", + "Total params: 49,719,414\n", + "Trainable params: 49,719,414\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.11\n", + "Forward/backward pass size (MB): 
884.58\n", + "Params size (MB): 189.66\n", + "Estimated Total Size (MB): 1074.36\n", + "----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(Inception_Tabular_15_1(), input_size=(1, 2500, 12))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Net(\n", + " (conv1): Sequential(\n", + " (0): Conv2d(1, 64, kernel_size=(7, 3), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv2): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv3): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv4): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv5): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv6): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv7): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (fc1): Sequential(\n", + " (0): Flatten()\n", + " (1): Linear(in_features=55680, out_features=64, bias=True)\n", + " (2): ReLU()\n", + " (3): Linear(in_features=64, out_features=32, bias=True)\n", + " (4): ReLU()\n", + " (5): Linear(in_features=32, out_features=1, bias=True)\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import TensorDataset\n", + "import torch.optim as optim\n", + "from torch.optim import lr_scheduler\n", + "import numpy as np\n", + "import torchvision\n", + "import torch.nn.functional as F\n", + "from torch.utils.data.sampler import SubsetRandomSampler\n", + "from torch.utils.data import DataLoader\n", + "from torchvision import datasets, models, transforms\n", + "from torchvision.transforms import Resize, ToTensor, Normalize\n", + "import matplotlib.pyplot as plt\n", + "from imblearn.under_sampling import RandomUnderSampler\n", + "\n", + "from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_auc_score, \\\n", + " average_precision_score\n", + "from sklearn.model_selection import train_test_split\n", + "import time\n", + "import os\n", + "from pathlib import Path\n", + "from skimage import io\n", + "import copy\n", + "from torch import optim, cuda\n", + "import pandas as pd\n", + "import glob\n", + "from collections import Counter\n", + "# Useful for examining network\n", + "from functools import reduce\n", + "from operator import __add__\n", + "# from torchsummary import summary\n", + "import seaborn as sns\n", + "import warnings\n", + "# warnings.filterwarnings('ignore', category=FutureWarning)\n", + "from PIL import Image\n", + "from timeit import default_timer as timer\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Useful for examining network\n", + "from functools import reduce\n", + "from operator import __add__\n", + "from torchsummary import summary\n", + "\n", + "# from IPython.core.interactiveshell import InteractiveShell\n", + "import seaborn as sns\n", + "\n", + "import warnings\n", + "# warnings.filterwarnings('ignore', category=FutureWarning)\n", + "\n", + "# Image manipulations\n", + "from PIL import Image\n", + "\n", + "# Timing utility\n", + "from timeit import default_timer as timer\n", + "\n", + "# Visualizations\n", + "import matplotlib.pyplot as plt\n", + "class Net(nn.Module):\n", + "\n", + " # Convolution as a whole should span at least 1 beat, preferably more\n", + " # Input shape = (1,2500,12)\n", + "# summary(model, input_size=(1, 2500, 12))\n", + " # CLEAR EXPERIMENTS TO TRY\n", + "# Alter kernel size, right now drops to 10 channels at beginning then stays there\n", + "# Try increasing output channel size as you go deeper \n", + "# Alter stride to have larger image at FC layer\n", + "\n", + " def __init__(self):\n", + " super(Net, self).__init__()\n", + "\n", + " base_conv = 64\n", + " self.conv1 = 
nn.Sequential( \n", + " nn.Conv2d(in_channels=1, out_channels=base_conv, kernel_size=(7,3), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(3,1))\n", + " )\n", + "\n", + " self.conv2 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n", + " )\n", + "\n", + " self.conv3 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(3,1))\n", + " )\n", + "\n", + " self.conv4 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n", + " )\n", + "\n", + " self.conv5 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(3,1))\n", + " )\n", + " self.conv6 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n", + " )\n", + "\n", + " self.conv7 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n", + " )\n", + "\n", + " self.fc1 = nn.Sequential(\n", + "# nn.AdaptiveAvgPool2d((1, 1)),\n", + " nn.Flatten(),\n", + " nn.Linear(64*87*10, 64), #64 kernel size, 2500 pooled to 29, \n", + " nn.ReLU(),\n", + " nn.Linear(64, 32),\n", + " nn.ReLU(),\n", + " nn.Linear(32, 1))\n", + "\n", + " def forward(self, x):\n", + " out = self.conv1(x)\n", + "# print(out.shape)\n", + " out = self.conv2(out)\n", + "# print(out.shape) \n", + " out = self.conv3(out)\n", + "# print(out.shape)\n", + " out = self.conv4(out)\n", + "# print(out.shape)\n", + " out = self.conv5(out)\n", + "# print(out.shape)\n", + " out = self.conv6(out)\n", + "# print(out.shape)\n", + " out = self.conv7(out)\n", + "# print(out.shape)\n", + " out = self.fc1(out)\n", + " return out\n", + "\n", + 
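"# Added for illustration (not part of the original run): the torchsummary totals below\n",
+ "# can be cross-checked by counting trainable parameters directly:\n",
+ "# print(sum(p.numel() for p in model.parameters() if p.requires_grad))\n",
+ "# Per-layer counts follow the usual conv formula, e.g. the first conv above is\n",
+ "# 64*1*7*3 weights + 64 biases = 1,408 params, matching Conv2d-1 in the summary.\n",
+ 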
"model = Net()\n", + "print(model)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([2, 64, 831, 10])\n", + "torch.Size([2, 64, 829, 10])\n", + "torch.Size([2, 64, 276, 10])\n", + "torch.Size([2, 64, 274, 10])\n", + "torch.Size([2, 64, 91, 10])\n", + "torch.Size([2, 64, 89, 10])\n", + "torch.Size([2, 64, 87, 10])\n", + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " Conv2d-1 [-1, 64, 2494, 10] 1,408\n", + " ReLU-2 [-1, 64, 2494, 10] 0\n", + " BatchNorm2d-3 [-1, 64, 2494, 10] 128\n", + " Conv2d-4 [-1, 64, 2494, 10] 4,160\n", + " ReLU-5 [-1, 64, 2494, 10] 0\n", + " BatchNorm2d-6 [-1, 64, 2494, 10] 128\n", + " MaxPool2d-7 [-1, 64, 831, 10] 0\n", + " Conv2d-8 [-1, 64, 831, 10] 4,160\n", + " ReLU-9 [-1, 64, 831, 10] 0\n", + " BatchNorm2d-10 [-1, 64, 831, 10] 128\n", + " Conv2d-11 [-1, 64, 831, 10] 4,160\n", + " ReLU-12 [-1, 64, 831, 10] 0\n", + " BatchNorm2d-13 [-1, 64, 831, 10] 128\n", + " MaxPool2d-14 [-1, 64, 829, 10] 0\n", + " Conv2d-15 [-1, 64, 829, 10] 4,160\n", + " ReLU-16 [-1, 64, 829, 10] 0\n", + " BatchNorm2d-17 [-1, 64, 829, 10] 128\n", + " Conv2d-18 [-1, 64, 829, 10] 4,160\n", + " ReLU-19 [-1, 64, 829, 10] 0\n", + " BatchNorm2d-20 [-1, 64, 829, 10] 128\n", + " MaxPool2d-21 [-1, 64, 276, 10] 0\n", + " Conv2d-22 [-1, 64, 276, 10] 4,160\n", + " ReLU-23 [-1, 64, 276, 10] 0\n", + " BatchNorm2d-24 [-1, 64, 276, 10] 128\n", + " Conv2d-25 [-1, 64, 276, 10] 4,160\n", + " ReLU-26 [-1, 64, 276, 10] 0\n", + " BatchNorm2d-27 [-1, 64, 276, 10] 128\n", + " MaxPool2d-28 [-1, 64, 274, 10] 0\n", + " Conv2d-29 [-1, 64, 274, 10] 4,160\n", + " ReLU-30 [-1, 64, 274, 10] 0\n", + " BatchNorm2d-31 [-1, 64, 274, 10] 128\n", + " Conv2d-32 [-1, 64, 274, 10] 4,160\n", + " ReLU-33 [-1, 64, 274, 10] 0\n", + " BatchNorm2d-34 [-1, 64, 274, 10] 128\n", + " MaxPool2d-35 [-1, 64, 91, 10] 0\n", + " Conv2d-36 [-1, 64, 91, 10] 4,160\n", + " ReLU-37 [-1, 64, 91, 10] 0\n", + " BatchNorm2d-38 [-1, 64, 91, 10] 128\n", + " Conv2d-39 [-1, 64, 91, 10] 4,160\n", + " ReLU-40 [-1, 64, 91, 10] 0\n", + " BatchNorm2d-41 [-1, 64, 91, 10] 128\n", + " MaxPool2d-42 [-1, 64, 89, 10] 0\n", + " Conv2d-43 [-1, 64, 89, 10] 4,160\n", + " ReLU-44 [-1, 64, 89, 10] 0\n", + " BatchNorm2d-45 [-1, 64, 89, 10] 128\n", + " Conv2d-46 [-1, 64, 89, 10] 4,160\n", + " ReLU-47 [-1, 64, 89, 10] 0\n", + " BatchNorm2d-48 [-1, 64, 89, 10] 128\n", + " MaxPool2d-49 [-1, 64, 87, 10] 0\n", + " Flatten-50 [-1, 55680] 0\n", + " Linear-51 [-1, 64] 3,563,584\n", + " ReLU-52 [-1, 64] 0\n", + " Linear-53 [-1, 32] 2,080\n", + " ReLU-54 [-1, 32] 0\n", + " Linear-55 [-1, 1] 33\n", + "================================================================\n", + "Total params: 3,622,977\n", + "Trainable params: 3,622,977\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.11\n", + "Forward/backward pass size (MB): 155.61\n", + "Params size (MB): 13.82\n", + "Estimated Total Size (MB): 169.54\n", + 
"----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(model, input_size=(1, 2500, 12))" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Net(\n", + " (conv1): Sequential(\n", + " (0): Conv2d(1, 64, kernel_size=(7, 3), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv2): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv3): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv4): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv5): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv6): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv7): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (fc1): Sequential(\n", + " (0): Flatten()\n", + " (1): Linear(in_features=55680, out_features=64, bias=True)\n", + " (2): ReLU()\n", + " (3): Linear(in_features=64, out_features=32, bias=True)\n", + " (4): ReLU()\n", + " (5): Linear(in_features=32, out_features=1, bias=True)\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import TensorDataset\n", + "import torch.optim as optim\n", + "from torch.optim import lr_scheduler\n", + "import numpy as np\n", + "import torchvision\n", + "import torch.nn.functional as F\n", + "from torch.utils.data.sampler import SubsetRandomSampler\n", + "from torch.utils.data import DataLoader\n", + "from torchvision import datasets, models, transforms\n", + "from torchvision.transforms import Resize, ToTensor, Normalize\n", + "import matplotlib.pyplot as plt\n", + "from imblearn.under_sampling import RandomUnderSampler\n", + "\n", + "from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_auc_score, \\\n", + " average_precision_score\n", + "from sklearn.model_selection import train_test_split\n", + "import time\n", + "import os\n", + "from pathlib import Path\n", + "from skimage import io\n", + "import copy\n", + "from torch import optim, cuda\n", + "import pandas as pd\n", + "import glob\n", + "from collections import Counter\n", + "# Useful for examining network\n", + "from functools import reduce\n", + "from operator import __add__\n", + "# from torchsummary import summary\n", + "import seaborn as sns\n", + "import warnings\n", + "# warnings.filterwarnings('ignore', category=FutureWarning)\n", + "from PIL import Image\n", + "from timeit import default_timer as timer\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Useful for examining network\n", + "from functools import reduce\n", + "from operator import __add__\n", + "from torchsummary import summary\n", + "\n", + "# from IPython.core.interactiveshell import InteractiveShell\n", + "import seaborn as sns\n", + "\n", + "import warnings\n", + "# warnings.filterwarnings('ignore', category=FutureWarning)\n", + "\n", + "# Image manipulations\n", + "from PIL import Image\n", + "\n", + "# Timing utility\n", + "from timeit import default_timer as timer\n", + "\n", + "# Visualizations\n", + "import matplotlib.pyplot as plt\n", + "class Net(nn.Module):\n", + "\n", + " # Convolution as a whole should span at least 1 beat, preferably more\n", + " # Input shape = (1,2500,12)\n", + "# summary(model, input_size=(1, 2500, 12))\n", + " # CLEAR EXPERIMENTS TO TRY\n", + "# Alter kernel size, right now drops to 10 channels at beginning then stays there\n", + "# Try increasing output channel size as you go deeper \n", + "# Alter stride to have larger image at FC layer\n", + "\n", + " def __init__(self):\n", + " super(Net, self).__init__()\n", + "\n", + " base_conv = 64\n", + " self.conv1 = nn.Sequential( \n", + " nn.Conv2d(in_channels=1, out_channels=base_conv, kernel_size=(7,3), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(3,1))\n", + " )\n", + "\n", + " self.conv2 = nn.Sequential(\n", + " 
nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n", + " )\n", + "\n", + " self.conv3 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(3,1))\n", + " )\n", + "\n", + " self.conv4 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n", + " )\n", + "\n", + " self.conv5 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(3,1))\n", + " )\n", + " self.conv6 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n", + " )\n", + "\n", + " self.conv7 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(3,1), stride=(1,1))\n", + " )\n", + "\n", + " self.fc1 = nn.Sequential(\n", + "# nn.AdaptiveAvgPool2d((1, 1)),\n", + " nn.Flatten(),\n", + " nn.Linear(64*87*10, 64), #64 kernel size, 2500 pooled to 29, \n", + " nn.ReLU(),\n", + " nn.Linear(64, 32),\n", + " nn.ReLU(),\n", + " nn.Linear(32, 1))\n", + "\n", + " def forward(self, x):\n", + " out = self.conv1(x)\n", + "# print(out.shape)\n", + " out = self.conv2(out)\n", + "# print(out.shape) \n", + " out = self.conv3(out)\n", + "# print(out.shape)\n", + " out = self.conv4(out)\n", + "# print(out.shape)\n", + " out = self.conv5(out)\n", + "# print(out.shape)\n", + " out = self.conv6(out)\n", + "# print(out.shape)\n", + " out = self.conv7(out)\n", + "# print(out.shape)\n", + " out = self.fc1(out)\n", + " return out\n", + "\n", + "model = Net()\n", + "print(model)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 109, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Net(\n", + " (conv1): Sequential(\n", + " (0): Conv2d(1, 16, kernel_size=(5, 1), stride=(1, 1))\n", + " (1): 
ReLU()\n", + " (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv2): Sequential(\n", + " (0): Conv2d(16, 16, kernel_size=(5, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv3): Sequential(\n", + " (0): Conv2d(16, 32, kernel_size=(5, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(4, 1), stride=(4, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv4): Sequential(\n", + " (0): Conv2d(32, 32, kernel_size=(3, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv5): Sequential(\n", + " (0): Conv2d(32, 64, kernel_size=(3, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv6): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(3, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=(4, 1), stride=(4, 1), padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv7): Sequential(\n", + " (0): Conv2d(64, 64, kernel_size=(1, 12), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (fc1): Sequential(\n", + " (0): Flatten()\n", + " (1): Linear(in_features=512, out_features=64, bias=True)\n", + " (2): ReLU()\n", + " (3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (4): Linear(in_features=64, out_features=32, bias=True)\n", + " (5): ReLU()\n", + " (6): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (7): Linear(in_features=32, out_features=1, bias=True)\n", + " )\n", + ")\n" + ] + } + ], + 
"source": [ + "import matplotlib.pyplot as plt\n", + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super(Net, self).__init__()\n", + "\n", + " base_conv = 16\n", + " self.conv1 = nn.Sequential( \n", + " nn.Conv2d(in_channels=1, out_channels=base_conv, kernel_size=(5,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(2,1), stride=(2,1)) \n", + " )\n", + "\n", + " self.conv2 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(5,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv),\n", + " nn.MaxPool2d(kernel_size=(2,1), stride=(2,1)) \n", + " )\n", + "\n", + " self.conv3 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv, out_channels=base_conv*2, kernel_size=(5,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv*2),\n", + " nn.Conv2d(in_channels=base_conv*2, out_channels=base_conv*2, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv*2),\n", + " nn.MaxPool2d(kernel_size=(4,1), stride=(4,1)) \n", + " )\n", + "\n", + " self.conv4 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv*2, out_channels=base_conv*2, kernel_size=(3,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv*2),\n", + " nn.Conv2d(in_channels=base_conv*2, out_channels=base_conv*2, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv*2),\n", + " nn.MaxPool2d(kernel_size=(2,1), stride=(2,1)) \n", + " )\n", + "\n", + " self.conv5 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv*2, out_channels=base_conv*4, kernel_size=(3,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv*4),\n", + " nn.Conv2d(in_channels=base_conv*4, out_channels=base_conv*4, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv*4),\n", + " nn.MaxPool2d(kernel_size=(2,1), stride=(2,1)) \n", + " )\n", + " self.conv6 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv*4, out_channels=base_conv*4, kernel_size=(3,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv*4),\n", + " nn.Conv2d(in_channels=base_conv*4, out_channels=base_conv*4, kernel_size=(1,1), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv*4),\n", + " nn.MaxPool2d(kernel_size=(4,1), stride=(4,1)) \n", + " )\n", + "\n", + " self.conv7 = nn.Sequential(\n", + " nn.Conv2d(in_channels=base_conv*4, out_channels=base_conv*4, kernel_size=(1,12), stride=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(base_conv*4),\n", + " ) #Mayo describess using their 7th block as a spatial fusion block\n", + "\n", + " self.fc1 = nn.Sequential(\n", + "# nn.AdaptiveAvgPool2d((1, 1)),\n", + " nn.Flatten(),\n", + " nn.Linear(64*8*1, 64), #64 kernel size, #2500 decreased to 8, x1 lead\n", + " nn.ReLU(),\n", + " nn.BatchNorm1d(64),\n", + " nn.Linear(64, 32),\n", + " nn.ReLU(),\n", + " nn.BatchNorm1d(32),\n", + "# nn.Dropout(0.5),\n", + " nn.Linear(32, 1))\n", + "\n", + " def forward(self, x):\n", + " out = self.conv1(x)\n", + "# print(out.shape)\n", + " out = self.conv2(out)\n", + "# print(out.shape) \n", + " out = self.conv3(out)\n", + "# print(out.shape)\n", + " out = self.conv4(out)\n", + "# print(out.shape)\n", + " out = 
self.conv5(out)\n", + "# print(out.shape)\n", + " out = self.conv6(out)\n", + "# print(out.shape)\n", + " out = self.conv7(out)\n", + "# print(out.shape)\n", + " out = self.fc1(out)\n", + " return out\n", + "\n", + "model = Net()\n", + "print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 110, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " Conv2d-1 [-1, 16, 2496, 12] 96\n", + " ReLU-2 [-1, 16, 2496, 12] 0\n", + " BatchNorm2d-3 [-1, 16, 2496, 12] 32\n", + " Conv2d-4 [-1, 16, 2496, 12] 272\n", + " ReLU-5 [-1, 16, 2496, 12] 0\n", + " BatchNorm2d-6 [-1, 16, 2496, 12] 32\n", + " MaxPool2d-7 [-1, 16, 1248, 12] 0\n", + " Conv2d-8 [-1, 16, 1244, 12] 1,296\n", + " ReLU-9 [-1, 16, 1244, 12] 0\n", + " BatchNorm2d-10 [-1, 16, 1244, 12] 32\n", + " Conv2d-11 [-1, 16, 1244, 12] 272\n", + " ReLU-12 [-1, 16, 1244, 12] 0\n", + " BatchNorm2d-13 [-1, 16, 1244, 12] 32\n", + " MaxPool2d-14 [-1, 16, 622, 12] 0\n", + " Conv2d-15 [-1, 32, 618, 12] 2,592\n", + " ReLU-16 [-1, 32, 618, 12] 0\n", + " BatchNorm2d-17 [-1, 32, 618, 12] 64\n", + " Conv2d-18 [-1, 32, 618, 12] 1,056\n", + " ReLU-19 [-1, 32, 618, 12] 0\n", + " BatchNorm2d-20 [-1, 32, 618, 12] 64\n", + " MaxPool2d-21 [-1, 32, 154, 12] 0\n", + " Conv2d-22 [-1, 32, 152, 12] 3,104\n", + " ReLU-23 [-1, 32, 152, 12] 0\n", + " BatchNorm2d-24 [-1, 32, 152, 12] 64\n", + " Conv2d-25 [-1, 32, 152, 12] 1,056\n", + " ReLU-26 [-1, 32, 152, 12] 0\n", + " BatchNorm2d-27 [-1, 32, 152, 12] 64\n", + " MaxPool2d-28 [-1, 32, 76, 12] 0\n", + " Conv2d-29 [-1, 64, 74, 12] 6,208\n", + " ReLU-30 [-1, 64, 74, 12] 0\n", + " BatchNorm2d-31 [-1, 64, 74, 12] 128\n", + " Conv2d-32 [-1, 64, 74, 12] 4,160\n", + " ReLU-33 [-1, 64, 74, 12] 0\n", + " BatchNorm2d-34 [-1, 64, 74, 12] 128\n", + " MaxPool2d-35 [-1, 64, 37, 12] 0\n", + " Conv2d-36 [-1, 64, 35, 12] 12,352\n", + " ReLU-37 [-1, 64, 35, 12] 0\n", + " BatchNorm2d-38 [-1, 64, 35, 12] 128\n", + " Conv2d-39 [-1, 64, 35, 12] 4,160\n", + " ReLU-40 [-1, 64, 35, 12] 0\n", + " BatchNorm2d-41 [-1, 64, 35, 12] 128\n", + " MaxPool2d-42 [-1, 64, 8, 12] 0\n", + " Conv2d-43 [-1, 64, 8, 1] 49,216\n", + " ReLU-44 [-1, 64, 8, 1] 0\n", + " BatchNorm2d-45 [-1, 64, 8, 1] 128\n", + " Flatten-46 [-1, 512] 0\n", + " Linear-47 [-1, 64] 32,832\n", + " ReLU-48 [-1, 64] 0\n", + " BatchNorm1d-49 [-1, 64] 128\n", + " Linear-50 [-1, 32] 2,080\n", + " ReLU-51 [-1, 32] 0\n", + " BatchNorm1d-52 [-1, 32] 64\n", + " Linear-53 [-1, 1] 33\n", + "================================================================\n", + "Total params: 122,001\n", + "Trainable params: 122,001\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.11\n", + "Forward/backward pass size (MB): 53.93\n", + "Params size (MB): 0.47\n", + "Estimated Total Size (MB): 54.51\n", + "----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(model, input_size=(1, 2500, 12))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": 
"text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}