Commit c5086ea

Merge pull request #10 from matlab-deep-learning/feature/5
Perform image type conversion upfront
2 parents: 90a65ea + 6a0ebdf

5 files changed (+54, -12 lines)

code/mtcnn/+mtcnn/Detector.m

Lines changed: 29 additions & 7 deletions
@@ -70,9 +70,7 @@
             %
             % See also: mtcnn.detectFaces

-            if obj.UseGPU()
-                im = gpuArray(single(im));
-            end
+            im = obj.prepImage(im);

             bboxes = [];
             scores = [];
@@ -103,7 +101,7 @@
             end

             %% Stage 2 - Refinement
-            [cropped, bboxes] = obj.prepImages(im, bboxes, obj.RnetSize);
+            [cropped, bboxes] = obj.prepBbox(im, bboxes, obj.RnetSize);
             [probs, correction] = mtcnn.rnet(cropped, obj.RnetWeights);
             [scores, bboxes] = obj.processOutputs(probs, correction, bboxes, 2);

@@ -112,7 +110,7 @@
             end

             %% Stage 3 - Output
-            [cropped, bboxes] = obj.prepImages(im, bboxes, obj.OnetSize);
+            [cropped, bboxes] = obj.prepBbox(im, bboxes, obj.OnetSize);

             % Adjust bboxes for the behaviour of imcrop
             bboxes(:, 1:2) = bboxes(:, 1:2) - 0.5;
@@ -144,12 +142,12 @@ function loadWeights(obj)
             obj.OnetWeights = load(fullfile(mtcnnRoot(), "weights", "onet.mat"));
         end

-        function [cropped, bboxes] = prepImages(obj, im, bboxes, outputSize)
+        function [cropped, bboxes] = prepBbox(obj, im, bboxes, outputSize)
             % prepImages Pre-process the images and bounding boxes.
             bboxes = mtcnn.util.makeSquare(bboxes);
             bboxes = round(bboxes);
             cropped = mtcnn.util.cropImage(im, bboxes, outputSize);
-            cropped = dlarray(single(cropped)./255*2 - 1, "SSCB");
+            cropped = dlarray(cropped, "SSCB");

         end

@@ -167,5 +165,29 @@ function loadWeights(obj)
                     "OverlapThreshold", obj.NmsThresholds(netIdx));
             end
         end
+
+        function outIm = prepImage(obj, im)
+            % convert the image to the correct scaling and type
+            % All images should be scaled to -1 to 1 and of single type
+            % also place on the GPU if required
+
+            switch class(im)
+                case "uint8"
+                    outIm = single(im)/255*2 - 1;
+                case "single"
+                    % expect floats to be 0-1 scaled
+                    outIm = im*2 - 1;
+                case "double"
+                    outIm = single(im)*2 - 1;
+                otherwise
+                    error("mtcnn:Detector:UnsupportedType", ...
+                        "Input image is of unsupported type '%s'", class(im));
+            end
+
+            if obj.UseGPU()
+                outIm = gpuArray(outIm);
+            end
+
+        end
     end
 end
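
The updated tests further down exercise this path for all three supported input types. A minimal usage sketch of what the change enables, based on the diff and the new test parameters (float images are expected to be 0-1 scaled):

% With prepImage handling the conversion upfront, detect() accepts
% uint8, single, or double images and rescales them to single in the
% [-1, 1] range before running the networks.
im = imread("visionteam.jpg");              % uint8, 0-255
detector = mtcnn.Detector();

bboxesA = detector.detect(im);              % uint8 input
bboxesB = detector.detect(single(im)/255);  % single input, 0-1 scaled
bboxesC = detector.detect(double(im)/255);  % double input, 0-1 scaled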

code/mtcnn/+mtcnn/proposeRegions.m

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
 % proposeRegions Generate region proposals at a given scale.
 %
 % Args:
-%   im - Input image 0-255 range
+%   im - Input image -1 to 1 range, type single
 %   scale - Scale to run proposal at
 %   threshold - Confidence threshold to accept proposal
 %   weights - P-Net weights struct
@@ -19,7 +19,7 @@
     pnetSize = 12;

     im = imresize(im, 1/scale);
-    im = dlarray(single(im)./255*2 - 1, "SSCB");
+    im = dlarray(im, "SSCB");

     [probability, correction] = mtcnn.pnet(im, weights);

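Since the conversion now happens in Detector.prepImage, proposeRegions assumes its input is already single and scaled to [-1, 1]. Below is a hedged sketch of a direct call under that assumption; the weights file name, the scale and threshold values, and the output variable name are illustrative guesses rather than details from this commit:

% Caller is now responsible for the type/range conversion.
im = imread("visionteam.jpg");
im = single(im)/255*2 - 1;                                     % rescale to [-1, 1]
weights = load(fullfile(mtcnnRoot(), "weights", "pnet.mat"));  % assumed file name
proposals = mtcnn.proposeRegions(im, 4, 0.6, weights);         % assumed output name
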
test/+tests/DetectorTest.m

Lines changed: 13 additions & 3 deletions
@@ -8,6 +8,12 @@
         Reference
     end

+    properties (TestParameter)
+        imageTypeConversion = struct("uint8", @(x) x, ...
+            "single", @(x) single(x)/255, ...
+            "double", @(x) double(x)/255)
+    end
+
     methods (TestClassSetup)
         function setupTestImage(test)
             test.Image = imread("visionteam.jpg");
@@ -26,10 +32,14 @@ function testCreate(test)
             detector = mtcnn.Detector();
         end

-        function testDetectwithDefaults(test)
+        function testDetectwithDefaults(test, imageTypeConversion)
+            % Test expected inputs with images of type uint8, single,
+            % double (float images are scaled 0-1);
             detector = mtcnn.Detector();

-            [bboxes, scores, landmarks] = detector.detect(test.Image);
+            inputImage = imageTypeConversion(test.Image);
+
+            [bboxes, scores, landmarks] = detector.detect(inputImage);

             test.verifyEqual(size(bboxes), [6, 4]);
             test.verifyEqual(size(scores), [6, 1]);
@@ -118,4 +128,4 @@ function testGpuDetect(test)
             test.verifyEqual(landmarks, test.Reference.landmarks, "RelTol", 1e-1);
         end
     end
-end
+end
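
Because imageTypeConversion is a TestParameter, the MATLAB unit test framework expands testDetectwithDefaults into one run per field (uint8, single, double). A short sketch of running the suite to see this parameterisation:

% Run the detector tests; the parameterised test appears once per
% imageTypeConversion entry in the results table.
results = runtests("tests.DetectorTest");
table(results)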

test/makeDetectionReference.m

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+function makeDetectionReference()
+% Run the detector in known good config to create reference boxes,
+% scores and landmarks for regression tests.
+    im = imread("visionteam.jpg");
+    [bboxes, scores, landmarks] = mtcnn.detectFaces(im);
+
+    filename = fullfile(mtcnnTestRoot(), "resources", "ref.mat");
+    save(filename, "bboxes", "scores", "landmarks");
+
+end
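
This helper presumably produced the test/resources/ref.mat binary added in this commit. A brief usage sketch, assuming the test folder is on the path so mtcnnTestRoot() resolves:

% Regenerate the regression reference used by DetectorTest.
makeDetectionReference();
ref = load(fullfile(mtcnnTestRoot(), "resources", "ref.mat"));
disp(fieldnames(ref))    % bboxes, scores, landmarks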

test/resources/ref.mat

2.38 KB (binary file not shown)
