diff --git a/test.lua b/test.lua index fb65bd90..02b63d43 100644 --- a/test.lua +++ b/test.lua @@ -978,9 +978,9 @@ local function BatchNormalization_backward(moduleName, mode, inputSize, backward end end -local function testBatchNormalization(name, dim, k) +local function testBatchNormalization(name, dim, k, batchsize) local function inputSize() - local inputSize = { torch.random(2,32), torch.random(1, k) } + local inputSize = { batchsize or torch.random(2,32), torch.random(1, k) } for i=1,dim do table.insert(inputSize, torch.random(1,k)) end @@ -1005,6 +1005,7 @@ end function cunntest.BatchNormalization() testBatchNormalization('BatchNormalization', 0, 128) + testBatchNormalization('BatchNormalization', 0, 128, 1) -- test batchsize=1 end function cunntest.SpatialBatchNormalization()