Diff of /demo.py [000000] .. [b76f82]

Switch to unified view

a b/demo.py
1
import torch
2
from EfficientNet_2d.EfficientNet_2d import get_pretrained_EfficientNet, get_pretrained_DAR
3
4
5
if __name__ == "__main__":
6
    # Phase 1:
7
    # pre-train prd-net, cf-net, and lr-net on CR-set, IC-set, and LR-set, respectively, and save the pre-trained model
8
    prd_net = get_pretrained_EfficientNet(num_classes=5)
9
    cf_net = get_pretrained_EfficientNet(num_classes=5)
10
    lr_net = get_pretrained_EfficientNet(num_classes=5)
11
12
    # Phase 2:
13
    # fine-tune dar on CR-set
14
    prd_params_path = "../your_checkpoint_path/prd_net_checkpoint.pth"
15
    cf_params_path = "../your_checkpoint_path/cf_net_checkpoint.pth"
16
    lr_params_path = "../your_checkpoint_path/lr_net_checkpoint.pth"
17
18
    prd_params = torch.load(prd_params_path)
19
    cf_params = torch.load(cf_params_path)
20
    lr_params = torch.load(lr_params_path)
21
22
    model = get_pretrained_DAR(prd_params, cf_params, lr_params, num_classes=5)
23
24
    # prediction
25
    imgs = torch.rand([4, 3, 224, 224])
26
    prd_preds, cf_preds, lr_preds = model(imgs)
27
    _, preds = torch.softmax(prd_preds, dim=1).max(dim=1)