|
a |
|
b/demo.py |
|
|
1 |
import torch |
|
|
2 |
from EfficientNet_2d.EfficientNet_2d import get_pretrained_EfficientNet, get_pretrained_DAR |
|
|
3 |
|
|
|
4 |
|
|
|
5 |
if __name__ == "__main__": |
|
|
6 |
# Phase 1: |
|
|
7 |
# pre-train prd-net, cf-net, and lr-net on CR-set, IC-set, and LR-set, respectively, and save the pre-trained model |
|
|
8 |
prd_net = get_pretrained_EfficientNet(num_classes=5) |
|
|
9 |
cf_net = get_pretrained_EfficientNet(num_classes=5) |
|
|
10 |
lr_net = get_pretrained_EfficientNet(num_classes=5) |
|
|
11 |
|
|
|
12 |
# Phase 2: |
|
|
13 |
# fine-tune dar on CR-set |
|
|
14 |
prd_params_path = "../your_checkpoint_path/prd_net_checkpoint.pth" |
|
|
15 |
cf_params_path = "../your_checkpoint_path/cf_net_checkpoint.pth" |
|
|
16 |
lr_params_path = "../your_checkpoint_path/lr_net_checkpoint.pth" |
|
|
17 |
|
|
|
18 |
prd_params = torch.load(prd_params_path) |
|
|
19 |
cf_params = torch.load(cf_params_path) |
|
|
20 |
lr_params = torch.load(lr_params_path) |
|
|
21 |
|
|
|
22 |
model = get_pretrained_DAR(prd_params, cf_params, lr_params, num_classes=5) |
|
|
23 |
|
|
|
24 |
# prediction |
|
|
25 |
imgs = torch.rand([4, 3, 224, 224]) |
|
|
26 |
prd_preds, cf_preds, lr_preds = model(imgs) |
|
|
27 |
_, preds = torch.softmax(prd_preds, dim=1).max(dim=1) |