Diff of /utils/segmentation.py [000000] .. [addb71]

Switch to unified view

a b/utils/segmentation.py
1
import cv2
2
import numpy as np
3
import os
4
from tensorflow.keras.models import load_model
5
6
# Load model at module level (only once)
7
MODEL_PATH = os.path.join(os.path.dirname(__file__), '..', 'app', 'model', 'unet_model.h5')
8
try:
9
    unet_model = load_model(MODEL_PATH)
10
    print(f"Successfully loaded model from {MODEL_PATH}")
11
except Exception as e:
12
    raise RuntimeError(f"Failed to load U-Net model: {str(e)}")
13
14
def smart_resize(image, target_size=(128, 128)):
15
    """Resizes with aspect ratio preservation using zero-padding"""
16
    h, w = image.shape[:2]
17
    scale = min(target_size[0]/h, target_size[1]/w)
18
    new_h, new_w = int(h * scale), int(w * scale)
19
    
20
    resized = cv2.resize(image, (new_w, new_h))
21
    delta_h = target_size[0] - new_h
22
    delta_w = target_size[1] - new_w
23
    
24
    padded = cv2.copyMakeBorder(resized,
25
                              delta_h//2, delta_h - delta_h//2,
26
                              delta_w//2, delta_w - delta_w//2,
27
                              cv2.BORDER_CONSTANT, value=0)
28
    return padded
29
30
def segment_lung(image_path):
31
    """Processes any X-ray image to segmentation mask"""
32
    try:
33
        # Read and validate
34
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
35
        if image is None:
36
            raise ValueError("Invalid image file")
37
        
38
        # Preprocess
39
        processed = smart_resize(image)
40
        processed = processed.astype(np.float32) / 255.0
41
        input_tensor = np.expand_dims(processed, axis=(0, -1))
42
        
43
        # Predict
44
        mask = unet_model.predict(input_tensor)[0]
45
        binary_mask = (mask > 0.5).astype(np.uint8) * 255
46
        
47
        # Resize back to original for clinical use
48
        final_mask = cv2.resize(binary_mask, (image.shape[1], image.shape[0]))
49
        return final_mask
50
        
51
    except Exception as e:
52
        raise ValueError(f"Segmentation failed: {str(e)}")
53
54
55
56
57
58