--- a +++ b/SemanticSegmentation.m @@ -0,0 +1,156 @@ +%% This code train and test semantic segmentation of Capsule endoscopy images +% required MATLAB 2018 +% Developer: Tonmoy Ghosh (tghosh@crimson.ua.edu) + +% Load Images +% Use |imageDatastore| to load images. The |imageDatastore| enables you +% to efficiently load a large collection of images on disk. +%% +clear; clc; close all; +%imgDir = fullfile('bleeding images') +imgDir = '/Users/tonmoyghosh/OneDrive - The University of Alabama/Paper with Code/Semantic Segmentation Based Bleeding Zone Detection/Dataset/bleeding'; +imds = imageDatastore(imgDir); +%% +% Display one of the images. + +I = readimage(imds, 1); +I = histeq(I); +figure +imshow(I) +%% Load Pixel-Labeled Images +classes = [ + "Bleeding" + "Non_Bleeding" + "Background" + ]; +labelIDs = PixelLabelIDs(); +%% +% Use the classes and label IDs to create the |pixelLabelDatastore|: + +%labelDir = fullfile('labels'); +labelDir = '/Users/tonmoyghosh/OneDrive - The University of Alabama/Paper with Code/Semantic Segmentation Based Bleeding Zone Detection/Dataset/labels'; +pxds = pixelLabelDatastore(labelDir,classes,labelIDs); +% Read and display one of the pixel-labeled images by overlaying it on top +% of an image. + +C = readimage(pxds, 1); + + +cmap = CEColorMap; +B = labeloverlay(I,C,'ColorMap',cmap); + +figure +imshow(B) +pixelLabelColorbar(cmap,classes); + +%% +%analize the data statistics +tbl = countEachLabel(pxds) + + +%Visualize the pixel counts by class. + +frequency = tbl.PixelCount/sum(tbl.PixelCount); + +figure +bar(1:numel(classes),frequency) +xticks(1:numel(classes)) +xticklabels(tbl.Name) +xtickangle(45) +ylabel('Frequency') + +%% +%Resize CamVid Data +imageFolder = fullfile('imagesReszed',filesep); +imds = resizeCEImages(imds,imageFolder); + +labelFolder = fullfile('labelsResized',filesep); +pxds = resizeCEPixelLabels(pxds,labelFolder); + + + +%% +%Prepare Training and Test Sets +[imdsTrain, imdsTest, pxdsTrain, pxdsTest] = partitionCEData(imds,pxds); + +numTrainingImages = numel(imdsTrain.Files) +numTestingImages = numel(imdsTest.Files) + +%Create the network +imageSize = [256 256 3]; +numClasses = numel(classes); +%lgraph = segnetLayers(imageSize,numClasses,'vgg16'); + + +%% + +%Balance Classes Using Class Weighting +% imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount; +% classWeights = median(imageFreq) ./ imageFreq +% +% pxLayer = pixelClassificationLayer('Name','labels','ClassNames', tbl.Name, 'ClassWeights', classWeights) +% +% +% lgraph = removeLayers(lgraph, 'pixelLabels'); +% lgraph = addLayers(lgraph, pxLayer); +% lgraph = connectLayers(lgraph, 'softmax' ,'labels'); + +% load saved network architecture +load lgraph + +%Select Training Options +options = trainingOptions('sgdm', ... + 'Momentum', 0.9, ... + 'InitialLearnRate', 1e-3, ... + 'L2Regularization', 0.0005, ... + 'MaxEpochs', 100, ... + 'MiniBatchSize', 3, ... + 'Shuffle', 'every-epoch', ... + 'VerboseFrequency', 2); + + + +%% + +%Data Augmentation +augmenter = imageDataAugmenter('RandXReflection',true,... + 'RandXTranslation', [-10 10], 'RandYTranslation',[-10 10]); + + +%Start Training +datasource = pixelLabelImageSource(imdsTrain,pxdsTrain, ... + 'DataAugmentation',augmenter); + +doTraining = false; +if doTraining + [net, info] = trainNetwork(datasource,lgraph,options); +else + data = load('CEtrainedSegNet.mat'); + net = data.net; +end + +%% +%Test Network on One Image +tic +I = read(imdsTest); +C = semanticseg(I, net); + +%Display the results. +B = labeloverlay(I, C, 'Colormap', cmap, 'Transparency',0.4); +figure +imshow(B) +pixelLabelColorbar(cmap, classes); + +expectedResult = read(pxdsTest); +actual = uint8(C); +expected = uint8(expectedResult{1}); +imshowpair(actual, expected) + +iou = jaccard(C, expectedResult{1}); +table(classes,iou) + +%Evaluate Trained Network +pxdsResults = semanticseg(imdsTest,net,'WriteLocation',tempdir,'Verbose',false); +metrics = evaluateSemanticSegmentation(pxdsResults,pxdsTest,'Verbose',false); + +toc