Skip to content

Commit a0a8ecc

Browse files
Merge pull request #13 from matlab-deep-learning/bugfix/12
Fixes #12
2 parents fb1fec4 + 755e64f commit a0a8ecc

File tree

6 files changed

+42
-41
lines changed

6 files changed

+42
-41
lines changed

code/mtcnn/+mtcnn/+util/DagNetworkStrategy.m

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,16 @@
55
Pnet
66
Rnet
77
Onet
8+
ExecutionEnvironment
89
end
910

1011
methods
11-
function obj = DagNetworkStrategy()
12+
function obj = DagNetworkStrategy(useGpu)
13+
if useGpu
14+
obj.ExecutionEnvironment = "gpu";
15+
else
16+
obj.ExecutionEnvironment = "cpu";
17+
end
1218
end
1319

1420
function load(obj)
@@ -18,19 +24,24 @@ function load(obj)
1824
obj.Onet = importdata(fullfile(mtcnnRoot(), "weights", "dagONet.mat"));
1925
end
2026

21-
function pnet = getPNet(obj)
22-
pnet = obj.Pnet;
27+
function [probability, correction] = applyPNet(obj, im)
28+
% need to use activations as we don't know what size it will be
29+
result = obj.Pnet.activations(im, "concat", ...
30+
"ExecutionEnvironment", obj.ExecutionEnvironment);
31+
32+
probability = result(:,:,1:2,:);
33+
correction = result(:,:,3:end,:);
2334
end
2435

2536
function [probs, correction] = applyRNet(obj, im)
26-
output = obj.Rnet.predict(im);
37+
output = obj.Rnet.predict(im, "ExecutionEnvironment", obj.ExecutionEnvironment);
2738

2839
probs = output(:,1:2);
2940
correction = output(:,3:end);
3041
end
3142

3243
function [probs, correction, landmarks] = applyONet(obj, im)
33-
output = obj.Onet.predict(im);
44+
output = obj.Onet.predict(im, "ExecutionEnvironment", obj.ExecutionEnvironment);
3445

3546
probs = output(:,1:2);
3647
correction = output(:,3:6);

code/mtcnn/+mtcnn/+util/DlNetworkStrategy.m

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,13 @@ function load(obj)
2626
end
2727
end
2828

29-
function pnet = getPNet(obj)
30-
pnet = obj.PnetWeights;
29+
function [probability, correction] = applyPNet(obj, im)
30+
im = dlarray(im, "SSCB");
31+
32+
[probability, correction] = mtcnn.pnet(im, obj.PnetWeights);
33+
34+
probability = extractdata(gather(probability));
35+
correction = extractdata(gather(correction));
3136
end
3237

3338
function [probs, correction] = applyRNet(obj, im)

code/mtcnn/+mtcnn/Detector.m

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
end
5050

5151
if obj.UseDagNet
52-
obj.Networks = mtcnn.util.DagNetworkStrategy();
52+
obj.Networks = mtcnn.util.DagNetworkStrategy(obj.UseGPU);
5353
else
5454
obj.Networks = mtcnn.util.DlNetworkStrategy(obj.UseGPU);
5555
end
@@ -89,7 +89,7 @@
8989
[thisBox, thisScore] = ...
9090
mtcnn.proposeRegions(im, scale, ...
9191
obj.ConfidenceThresholds(1), ...
92-
obj.Networks.getPNet());
92+
obj.Networks);
9393
bboxes = cat(1, bboxes, thisBox);
9494
scores = cat(1, scores, thisScore);
9595
end
@@ -183,7 +183,7 @@
183183
"Input image is of unsupported type '%s'", class(im));
184184
end
185185

186-
if obj.UseGPU()
186+
if obj.UseGPU && ~obj.UseDagNet
187187
outIm = gpuArray(outIm);
188188
end
189189

code/mtcnn/+mtcnn/proposeRegions.m

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function [bboxes, scores] = proposeRegions(im, scale, threshold, weightsOrNet)
1+
function [bboxes, scores] = proposeRegions(im, scale, threshold, networkStrategy)
22
% proposeRegions Generate region proposals at a given scale.
33
%
44
% Args:
@@ -13,34 +13,15 @@
1313

1414
% Copyright 2019 The MathWorks, Inc.
1515

16-
useDagNet = isa(weightsOrNet, "DAGNetwork");
17-
if isa(im, "gpuArray")
18-
imClass = classUnderlying(im);
19-
else
20-
imClass = class(im);
21-
end
22-
assert(imClass == "single", "mtcnn:proposeRegions:wrongImageType", ...
23-
"Input image should be a single scale -1 to 1");
2416

2517
% Stride of the proposal network
2618
stride = 2;
2719
% Field of view of the proposal network in pixels
2820
pnetSize = 12;
2921

3022
im = imresize(im, 1/scale);
31-
32-
if useDagNet
33-
% need to use activations as we don't know what size it will be
34-
result = weightsOrNet.activations(im, "concat");
35-
probability = gather(result(:,:,1:2,:));
36-
correction = gather(result(:,:,3:end,:));
37-
else
38-
im = dlarray(im, "SSCB");
39-
[probability, correction] = mtcnn.pnet(im, weightsOrNet);
40-
probability = extractdata(gather(probability));
41-
correction = extractdata(gather(correction));
42-
end
43-
23+
24+
[probability, correction] = networkStrategy.applyPNet(im);
4425

4526
faces = probability(:,:,2) > threshold;
4627
if sum(faces, 'all') == 0

test/+tests/DetectorTest.m

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,14 @@ function testNmsThresholds(test)
112112
end
113113

114114
%% GPU
115-
function testGpuDetect(test)
115+
function testGpuDetect(test, imageTypeConversion, useDagNet)
116+
116117
% filter if no GPU present
117118
test.assumeGreaterThan(gpuDeviceCount, 0);
118119

119-
detector = mtcnn.Detector("UseGPU", true);
120-
[bboxes, scores, landmarks] = detector.detect(test.Image);
120+
inputImage = imageTypeConversion(test.Image);
121+
detector = mtcnn.Detector("UseGPU", true, "UseDagNet", useDagNet);
122+
[bboxes, scores, landmarks] = detector.detect(inputImage);
121123

122124
test.verifyEqual(size(bboxes), [6, 4]);
123125
test.verifyEqual(size(scores), [6, 1]);

test/+tests/ProposeRegionsTest.m

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,18 @@
88
end
99

1010
properties (TestParameter)
11-
getNet = struct("weights", @() load(fullfile(mtcnnRoot, "weights", "pnet.mat")), ...
12-
"net", @() importdata(fullfile(mtcnnRoot, "weights", "dagPNet.mat")));
11+
getNet = struct("dl", @() mtcnn.util.DlNetworkStrategy(false) , ...
12+
"dag", @() mtcnn.util.DagNetworkStrategy(false));
1313
end
1414

1515
methods (Test)
1616
function testOutputs(test, getNet)
1717
scale = 2;
1818
conf = 0.5;
19-
weights = getNet();
19+
strategy = getNet();
20+
strategy.load();
2021

21-
[box, score] = mtcnn.proposeRegions(test.Image, scale, conf, weights);
22+
[box, score] = mtcnn.proposeRegions(test.Image, scale, conf, strategy);
2223

2324
test.verifyOutputs(box, score);
2425
end
@@ -29,9 +30,10 @@ function test1DActivations(test, getNet)
2930
cropped = imcrop(test.Image, [300, 42, 65, 38]);
3031
scale = 3;
3132
conf = 0.5;
32-
weights = getNet();
33+
strategy = getNet();
34+
strategy.load();
3335

34-
[box, score] = mtcnn.proposeRegions(cropped, scale, conf, weights);
36+
[box, score] = mtcnn.proposeRegions(cropped, scale, conf, strategy);
3537

3638
test.verifyOutputs(box, score);
3739
end

0 commit comments

Comments
 (0)