Diff of /test.py [000000] .. [ab8373]

Switch to unified view

a b/test.py
1
import os
2
import csv
3
import torch
4
from torchvision import transforms
5
from torch.utils.data import Dataset, DataLoader
6
from PIL import Image
7
from ultralytics import YOLO
8
9
10
TEST_DATASET = ""
11
YOLO_WEIGHTS = "/home/mlip/Desktop/bleedgen/yolo_v8_runs/runs/detect/train_valid_4/weights/best.pt"
12
13
transform = transforms.Compose([
14
    transforms.ToTensor(),           
15
])
16
17
class CustomDataset(Dataset):
18
    def __init__(self, root_dir, transform=None):
19
        self.root_dir = root_dir
20
        self.transform = transform
21
        self.images = os.listdir(root_dir)
22
23
    def __len__(self):
24
        return len(self.images)
25
26
    def __getitem__(self, idx):
27
        img_name = os.path.join(self.root_dir, self.images[idx])
28
        image = Image.open(img_name)
29
        if self.transform:
30
            image = self.transform(image)
31
        return img_name, image
32
33
test_dataset = CustomDataset(TEST_DATASET, transform=transform)
34
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
35
36
device = "cuda" if torch.cuda.is_available() else "cpu"
37
yolo_model = YOLO(YOLO_WEIGHTS)
38
39
classes = ['bleeding', 'non_bleeding']
40
with open('predictions.csv', mode='w', newline='') as file:
41
    writer = csv.writer(file)
42
    writer.writerow(['Image Name', 'Predicted Label(YOLOv8)', 'YOLO CONFIDENCE', "BOUNDING BOXES (xywhn)"])  # Write header
43
44
    for image_name, image in test_loader:
45
        image = image.to(device)
46
        yolo_predicted_label = "non_bleeding"
47
        with torch.no_grad():
48
            results = yolo_model.predict(image / 2.64)
49
            yolo_row = []
50
            for result in results :
51
                conf = result.boxes.conf.tolist()
52
                if (len(conf) > 0) :
53
                    yolo_predicted_label = 'bleeding'
54
                    yolo_row.append(result.boxes.conf.tolist())
55
                    yolo_row.append(result.boxes.xywhn.tolist())
56
57
        writer.writerow([image_name[0].split('/')[-1], yolo_predicted_label] + yolo_row)
58
59
print("CSV file with predictions saved as 'predictions.csv'.")