[c1b1c5]: / ViTPose / tests / test_backbones / test_alexnet.py

Download this file

22 lines (15 with data), 452 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmpose.models.backbones import AlexNet
def test_alexnet_backbone():
    """Test alexnet backbone.

    Exercises two configurations of ``AlexNet`` in training mode and checks
    only the shape of the forward-pass output:

    * ``AlexNet(-1)`` on a ``(1, 3, 256, 192)`` input must yield a
      ``(1, 256, 7, 5)`` feature map.
    * ``AlexNet(1)`` on a ``(1, 3, 224, 224)`` input must yield a ``(1, 1)``
      output.

    NOTE(review): the single positional constructor argument is presumably
    ``num_classes``, with ``-1`` meaning "no classification head". Confirm
    this against ``mmpose.models.backbones.AlexNet``.
    """
    # Configuration 1: constructor argument -1, 4-D feature-map output.
    model = AlexNet(-1)
    model.train()
    imgs = torch.randn(1, 3, 256, 192)
    # Only the output shape is inspected, so skip building an autograd graph.
    with torch.no_grad():
        feat = model(imgs)
    assert feat.shape == (1, 256, 7, 5)

    # Configuration 2: constructor argument 1, (batch, 1) output.
    model = AlexNet(1)
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        feat = model(imgs)
    assert feat.shape == (1, 1)