Diff of /SemanticSegmentation.m [000000] .. [06669b]

Switch to unified view

a b/SemanticSegmentation.m
1
%% This code train and test semantic segmentation of Capsule endoscopy images
2
% required MATLAB 2018
3
% Developer: Tonmoy Ghosh (tghosh@crimson.ua.edu)
4
5
% Load Images
6
% Use |imageDatastore| to load images. The |imageDatastore| enables you 
7
% to efficiently load a large collection of images on disk.
8
%%
9
clear; clc; close all;
10
%imgDir = fullfile('bleeding images')
11
imgDir = '/Users/tonmoyghosh/OneDrive - The University of Alabama/Paper with Code/Semantic Segmentation Based Bleeding Zone Detection/Dataset/bleeding';
12
imds = imageDatastore(imgDir);
13
%% 
14
% Display one of the images.
15
16
I = readimage(imds, 1);
17
I = histeq(I);
18
figure
19
imshow(I)
20
%% Load Pixel-Labeled Images
21
classes = [
22
    "Bleeding"
23
    "Non_Bleeding"
24
    "Background"
25
    ];
26
labelIDs = PixelLabelIDs();
27
%% 
28
% Use the classes and label IDs to create the |pixelLabelDatastore|:
29
30
%labelDir = fullfile('labels');
31
labelDir = '/Users/tonmoyghosh/OneDrive - The University of Alabama/Paper with Code/Semantic Segmentation Based Bleeding Zone Detection/Dataset/labels';
32
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);
33
% Read and display one of the pixel-labeled images by overlaying it on top 
34
% of an image.
35
36
C = readimage(pxds, 1);
37
38
39
cmap = CEColorMap;
40
B = labeloverlay(I,C,'ColorMap',cmap);
41
42
figure
43
imshow(B)
44
pixelLabelColorbar(cmap,classes);
45
46
%%
47
%analize the data statistics
48
tbl = countEachLabel(pxds)
49
50
51
%Visualize the pixel counts by class.
52
53
frequency = tbl.PixelCount/sum(tbl.PixelCount);
54
55
figure
56
bar(1:numel(classes),frequency)
57
xticks(1:numel(classes))
58
xticklabels(tbl.Name)
59
xtickangle(45)
60
ylabel('Frequency')
61
62
%%
63
%Resize CamVid Data
64
imageFolder = fullfile('imagesReszed',filesep);
65
imds = resizeCEImages(imds,imageFolder);
66
67
labelFolder = fullfile('labelsResized',filesep);
68
pxds = resizeCEPixelLabels(pxds,labelFolder);
69
70
71
72
%%
73
%Prepare Training and Test Sets
74
[imdsTrain, imdsTest, pxdsTrain, pxdsTest] = partitionCEData(imds,pxds);
75
76
numTrainingImages = numel(imdsTrain.Files)
77
numTestingImages = numel(imdsTest.Files)
78
79
%Create the network
80
imageSize = [256 256 3];
81
numClasses = numel(classes);
82
%lgraph = segnetLayers(imageSize,numClasses,'vgg16');
83
84
85
%%
86
87
%Balance Classes Using Class Weighting
88
% imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
89
% classWeights = median(imageFreq) ./ imageFreq
90
% 
91
% pxLayer = pixelClassificationLayer('Name','labels','ClassNames', tbl.Name, 'ClassWeights', classWeights)
92
% 
93
% 
94
% lgraph = removeLayers(lgraph, 'pixelLabels');
95
% lgraph = addLayers(lgraph, pxLayer);
96
% lgraph = connectLayers(lgraph, 'softmax' ,'labels');
97
98
% load saved network architecture
99
load lgraph
100
101
%Select Training Options
102
options = trainingOptions('sgdm', ...
103
    'Momentum', 0.9, ...
104
    'InitialLearnRate', 1e-3, ...
105
    'L2Regularization', 0.0005, ...
106
    'MaxEpochs', 100, ...
107
    'MiniBatchSize', 3, ...
108
    'Shuffle', 'every-epoch', ...
109
    'VerboseFrequency', 2);
110
111
112
113
%%
114
115
%Data Augmentation
116
augmenter = imageDataAugmenter('RandXReflection',true,...
117
    'RandXTranslation', [-10 10], 'RandYTranslation',[-10 10]);
118
119
120
%Start Training
121
datasource = pixelLabelImageSource(imdsTrain,pxdsTrain, ...
122
                    'DataAugmentation',augmenter);
123
124
doTraining = false;
125
if doTraining
126
    [net, info] = trainNetwork(datasource,lgraph,options);
127
else
128
    data = load('CEtrainedSegNet.mat');
129
    net = data.net;
130
end
131
132
%%
133
%Test Network on One Image
134
tic
135
I = read(imdsTest);
136
C = semanticseg(I, net);
137
138
%Display the results.
139
B = labeloverlay(I, C, 'Colormap', cmap, 'Transparency',0.4);
140
figure
141
imshow(B)
142
pixelLabelColorbar(cmap, classes);
143
144
expectedResult = read(pxdsTest);
145
actual = uint8(C);
146
expected = uint8(expectedResult{1});
147
imshowpair(actual, expected)
148
149
iou = jaccard(C, expectedResult{1});
150
table(classes,iou)
151
152
%Evaluate Trained Network
153
pxdsResults = semanticseg(imdsTest,net,'WriteLocation',tempdir,'Verbose',false);
154
metrics = evaluateSemanticSegmentation(pxdsResults,pxdsTest,'Verbose',false);
155
156
toc