From 32b63715477d4715fca44c995a1d22a4784d3eff Mon Sep 17 00:00:00 2001 From: Tan Li Boon Date: Wed, 11 Jan 2017 21:24:00 +0800 Subject: [PATCH 1/6] Modified cmake files to enable building again. --- cmake/Dependencies.cmake | 7 +-- cmake/Thirdparty/FindOpenCL.cmake | 88 +++++++++++++++++++++++++++++ cmake/Thirdparty/FindViennaCL.cmake | 11 +++- 3 files changed, 101 insertions(+), 5 deletions(-) create mode 100644 cmake/Thirdparty/FindOpenCL.cmake diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index d5bfbd926d..708628b094 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -91,12 +91,11 @@ ENDIF() IF(USE_OPENCL) FIND_PACKAGE(OpenCL REQUIRED) - IF(NOT OpenCL_FOUND) + IF(NOT OPENCL_FOUND) MESSAGE(SEND_ERROR "OpenCL was requested, but not found.") ELSE() - #MESSAGE(STATUS "Found OpenCL headers at ${OpenCL_INCLUDE_DIRS}") - INCLUDE_DIRECTORIES(SYSTEM ${OpenCL_INCLUDE_DIR}) - LIST(APPEND SINGA_LINKER_LIBS ${OpenCL_LIBRARIES}) + INCLUDE_DIRECTORIES(SYSTEM ${OPENCL_INCLUDE_DIR}) + LIST(APPEND SINGA_LINKER_LIBS ${OPENCL_LIBRARIES}) FIND_PACKAGE(ViennaCL REQUIRED) IF(NOT ViennaCL_FOUND) MESSAGE(SEND_ERROR "ViennaCL is required if OpenCL is enabled.") diff --git a/cmake/Thirdparty/FindOpenCL.cmake b/cmake/Thirdparty/FindOpenCL.cmake new file mode 100644 index 0000000000..c358d8a1a4 --- /dev/null +++ b/cmake/Thirdparty/FindOpenCL.cmake @@ -0,0 +1,88 @@ +# - Find the OpenCL headers and library +# +# Defines the following if found: +# OPENCL_FOUND : TRUE if found, FALSE otherwise +# OPENCL_INCLUDE_DIRS : Include directories for OpenCL +# OPENCL_LIBRARIES : The libraries to link against +# +# The user can set the OPENCLROOT environment variable to help finding OpenCL +# if it is installed in a non-standard place. + +set(ENV_ATISTREAMSDKROOT "$ENV{ATISTREAMSDKROOT}") +if(ENV_ATISTREAMSDKROOT) + set(ENV_OPENCLROOT "$ENV{ATISTREAMSDKROOT}") +endif(ENV_ATISTREAMSDKROOT) + +set(ENV_AMDAPPSDKROOT "$ENV{AMDAPPSDKROOT}") +if(ENV_AMDAPPSDKROOT) + set(ENV_OPENCLROOT "$ENV{AMDAPPSDKROOT}") +endif(ENV_AMDAPPSDKROOT) + +set(ENV_INTELOCLSDKROOT "$ENV{INTELOCLSDKROOT}") +if(ENV_INTELOCLSDKROOT) + set(ENV_OPENCLROOT "$ENV{INTELOCLSDKROOT}") +endif(ENV_INTELOCLSDKROOT) + +set(ENV_OPENCLROOT2 "$ENV{OPENCLROOT}") +if(ENV_OPENCLROOT2) + set(ENV_OPENCLROOT "$ENV{OPENCLROOT}") +endif(ENV_OPENCLROOT2) + +if(ENV_OPENCLROOT) + find_path( + OPENCL_INCLUDE_DIR + NAMES CL/cl.h OpenCL/cl.h + PATHS "${ENV_OPENCLROOT}/include" + #NO_DEFAULT_PATH #uncomment this is you wish to surpress the use of default paths for OpenCL + ) + + if (("${CMAKE_SYSTEM_NAME}" MATCHES "Linux") OR ("${CMAKE_SYSTEM_NAME}" MATCHES "Windows")) + if(CMAKE_SIZEOF_VOID_P EQUAL 4) + set(OPENCL_LIB_SEARCH_PATH + "${OPENCL_LIB_SEARCH_PATH}" + "${ENV_OPENCLROOT}/lib/x86") + else(CMAKE_SIZEOF_VOID_P EQUAL 4) + set(OPENCL_LIB_SEARCH_PATH + "${OPENCL_LIB_SEARCH_PATH}" + "${ENV_OPENCLROOT}/lib/x86_64") + endif(CMAKE_SIZEOF_VOID_P EQUAL 4) + endif(("${CMAKE_SYSTEM_NAME}" MATCHES "Linux") OR ("${CMAKE_SYSTEM_NAME}" MATCHES "Windows")) + find_library( + OPENCL_LIBRARY + NAMES OpenCL + PATHS "${OPENCL_LIB_SEARCH_PATH}" + #NO_DEFAULT_PATH #uncomment this is you wish to surpress the use of default paths for OpenCL + ) +else(ENV_OPENCLROOT) + find_path( + OPENCL_INCLUDE_DIR + NAMES CL/cl.h OpenCL/cl.h + PATHS "${PROJECT_SOURCE_DIR}" #use the CL/ include folder provided with ViennaCL + ) + + find_library( + OPENCL_LIBRARY + NAMES OpenCL + ) +endif(ENV_OPENCLROOT) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args( + OPENCL + DEFAULT_MSG + OPENCL_LIBRARY OPENCL_INCLUDE_DIR + ) + +if(OPENCL_FOUND) + set(OPENCL_INCLUDE_DIRS "${OPENCL_INCLUDE_DIR}") + set(OPENCL_LIBRARIES "${OPENCL_LIBRARY}") +else(OPENCL_FOUND) + set(OPENCL_INCLUDE_DIRS) + set(OPENCL_LIBRARIES) +endif(OPENCL_FOUND) + +mark_as_advanced( + OPENCL_INCLUDE_DIR + OPENCL_LIBRARY + ) + diff --git a/cmake/Thirdparty/FindViennaCL.cmake b/cmake/Thirdparty/FindViennaCL.cmake index 263c80fdca..f18c4be1d8 100644 --- a/cmake/Thirdparty/FindViennaCL.cmake +++ b/cmake/Thirdparty/FindViennaCL.cmake @@ -1,9 +1,14 @@ # This file is retrieved from caffe/cmake/Modules/FindViennaCL.cmake. +# from the opencl branch on BVLC/Caffe. SET(ViennaCL_WITH_OPENCL TRUE) SET(VIENNACL_INCLUDE_SEARCH_PATHS + viennacl + viennacl-dev .. + ../viennacl + ../viennacl-dev /usr/include /usr/local/include /opt/ViennaCL/include @@ -15,7 +20,7 @@ FIND_PATH(ViennaCL_INCLUDE_DIR NAMES viennacl/forwards.h PATHS ${VIENNACL_INCLUD SET(ViennaCL_FOUND ON) -# Check include files +# Check include files IF(NOT ViennaCL_INCLUDE_DIR) SET(ViennaCL_FOUND OFF) MESSAGE(STATUS "Could not find ViennaCL include. Turning ViennaCL_FOUND off") @@ -33,6 +38,10 @@ ENDIF (ViennaCL_FOUND) IF(ViennaCL_WITH_OPENCL) find_package(OpenCL REQUIRED) + IF(NOT OPENCL_INCLUDE_DIRS) + MESSAGE(FATAL_ERROR "Could not find OpenCL include.") + ENDIF() + MESSAGE(STATUS "Found OpenCL include: ${OPENCL_INCLUDE_DIRS}") ENDIF(ViennaCL_WITH_OPENCL) set(ViennaCL_INCLUDE_DIRS ${ViennaCL_INCLUDE_DIR} ${OPENCL_INCLUDE_DIRS}) From 97e962e2795a06cbf2401c23bb54e3b6cecc7c2c Mon Sep 17 00:00:00 2001 From: Tan Li Boon Date: Wed, 11 Jan 2017 21:53:43 +0800 Subject: [PATCH 2/6] Changed function signatures of ComputeCrossEntropy and SoftmaxCrossEntropyBwd to match the newer API. --- src/core/tensor/tensor_math_opencl.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/core/tensor/tensor_math_opencl.h b/src/core/tensor/tensor_math_opencl.h index a209de4a57..a29c7063dc 100644 --- a/src/core/tensor/tensor_math_opencl.h +++ b/src/core/tensor/tensor_math_opencl.h @@ -577,9 +577,9 @@ void GEMM(const bool transA, const bool transB, template <> -void ComputeCrossEntropy(const size_t batchsize, const size_t dim, - const Block *p, const Block *t, Block *loss, - Context *ctx) { +void ComputeCrossEntropy(bool int_target, const size_t batchsize, + const size_t dim, const Block *p, const Block *t, + Block *loss, Context *ctx) { auto ocl_ctx = get_context(ctx->vcl_ctx_id); auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_crossentropy"); @@ -592,7 +592,7 @@ void ComputeCrossEntropy(const size_t batchsize, const size template <> -void SoftmaxCrossEntropyBwd(const size_t batchsize, const size_t dim, +void SoftmaxCrossEntropyBwd(bool int_target, const size_t batchsize, const size_t dim, const Block *p, const Block *t, Block *grad, Context *ctx) { auto ocl_ctx = get_context(ctx->vcl_ctx_id); From 72930c2743a182f18086480a761fa32b0be2ba8c Mon Sep 17 00:00:00 2001 From: Tan Li Boon Date: Sat, 8 Oct 2016 23:00:43 +0800 Subject: [PATCH 3/6] SINGA-257 Convnet Benchmark - Creation of examples/cifar10/benchmark_alexnet.py. - Editing of said file to conform with the alexnet samples found in the convnet-benchmark repository. - Added benchmark_vgg, fixed typos in benchmark_alexnet. - Implemented missing OpenCL device code to make it callable from python. - Added cmakedefine for USE_OPENCL. - Moved benchmarking scripts to their own folder. - Create README for benchmarking. - Added overfeat benchmark. --- examples/benchmark/README.md | 5 + examples/benchmark/benchmark_alexnet.py | 146 ++++++++++++++++++++++ examples/benchmark/benchmark_overfeat.py | 146 ++++++++++++++++++++++ examples/benchmark/benchmark_vgg.py | 149 +++++++++++++++++++++++ include/singa/core/device.h | 40 +++--- python/singa/device.py | 16 +++ python/singa/layer.py | 5 +- python/singa/net.py | 1 + src/api/config.i.in | 1 + src/api/core_device.i | 19 ++- src/core/device/platform.cc | 53 ++++++-- src/core/tensor/tensor_math_opencl.h | 35 ++---- src/model/layer/opencl_convolution.cc | 2 +- src/model/layer/opencl_pooling.cc | 2 +- 14 files changed, 564 insertions(+), 56 deletions(-) create mode 100644 examples/benchmark/README.md create mode 100644 examples/benchmark/benchmark_alexnet.py create mode 100644 examples/benchmark/benchmark_overfeat.py create mode 100644 examples/benchmark/benchmark_vgg.py diff --git a/examples/benchmark/README.md b/examples/benchmark/README.md new file mode 100644 index 0000000000..2f5855d7ff --- /dev/null +++ b/examples/benchmark/README.md @@ -0,0 +1,5 @@ +#Benchmark scripts + +These scripts will create a neural net modelled after the ones specified in [convnet-benchmarks](https://github.com/soumith/convnet-benchmarks/tree/master/caffe/imagenet_winners). + +To run them, create a python pip virtualenv or anaconda virtual environment as guided by [this article](http://singa.apache.org/en/docs/installation.html#pip-and-anaconda-for-pysinga). Then, execute the python scripts in this folder. diff --git a/examples/benchmark/benchmark_alexnet.py b/examples/benchmark/benchmark_alexnet.py new file mode 100644 index 0000000000..89c76bdede --- /dev/null +++ b/examples/benchmark/benchmark_alexnet.py @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +''' This model is created following the structure from +https://github.com/soumith/convnet-benchmarks/blob/master/caffe/imagenet_winners/alexnet.prototxt +''' + +import sys +import timeit +import numpy as np + +from singa import device +from singa import layer +from singa import loss +from singa import metric +from singa import net as ffnet +from singa import tensor +from singa import optimizer +from singa.proto import core_pb2 + +iterations = 10 +batch_size = 128 +input_shape = (3, 224, 224) + +def create_net(use_cpu = False, use_ocl = False): + if use_cpu: + layer.engine = 'singacpp' + if use_ocl: + layer.engine = 'singacl' + + net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy()) + + # Conv 1 + net.add(layer.Conv2D("conv1", 64, 11, 4, pad=2, input_sample_shape=input_shape)) + net.add(layer.Activation("conv1/relu" )) + net.add(layer.MaxPooling2D("pool1/3x3_s2", 3, 2)) + + # Conv 2 + net.add(layer.Conv2D("conv1/5x5_s1", 192, 5, 1, pad=2)) + net.add(layer.Activation("conv2/relu")) + net.add(layer.MaxPooling2D("pool2/3x3_s2", 3, 2)) + + # Conv 3 + net.add(layer.Conv2D("conv3/3x3_s1", 384, 3, 1, pad=1)) + net.add(layer.Activation("conv3/relu")) + + # Conv 4 + net.add(layer.Conv2D("conv4/3x3_s1", 256, 3, 1, pad=1)) + net.add(layer.Activation("conv4/relu")) + + # Conv 5 + net.add(layer.Conv2D("conv5/3x3_s1", 256, 3, 1, pad=1)) + net.add(layer.Activation("conv5/relu")) + net.add(layer.MaxPooling2D("pool5/3x3_s2", 3, 2)) + + # L2 Norm -> Inner product + net.add(layer.Flatten("flat")) + net.add(layer.Dense("fc6", 4096)) + net.add(layer.Activation("fc6/relu6")) + + net.add(layer.Dense("fc7", 4096)) + net.add(layer.Activation("fc7/relu7")) + + net.add(layer.Dense("fc8", 1000)) + + for (val, spec) in zip(net.param_values(), net.param_specs()): + filler = spec.filler + if filler.type == 'gaussian': + val.gaussian(filler.mean, filler.std) + else: + val.set_value(0) + print spec.name, filler.type, val.l1() + + return net + +# Time forward, backward, parameter update, per layer (1x forward, 1x backward) +def train(net, dev): + tx = tensor.Tensor((batch_size,) + input_shape, dev) + ty = tensor.Tensor((batch_size,), dev) # Should be integers, but CUDA with int tensor is not supported yet + tx.gaussian(1.0, 0.5) + ty.set_value(0.0) + + opt = optimizer.SGD(momentum=0.9) + idx = np.arange(tx.shape[0], dtype = np.int32) + loss = 0.0 + acc = 0.0 + + train_time = 0.0 + update_time = 0.0 + for b in range(iterations): + + t0 = timeit.default_timer() + grads, (l, a) = net.train(tx, ty) + t1 = timeit.default_timer() + t1 -= t0 + train_time += t1 + + loss += l + acc += a + + t2 = timeit.default_timer() + for (s, p, g) in zip(net.param_names(), net.param_values(), grads): + opt.apply_with_lr(0, 0.01, g, p, str(s), b) + t3 = timeit.default_timer() + t3 -= t2 + update_time += t3 + + print("Iteration {}: Train: {}, Update: {}".format(b, t1, t3)) + + print("Total iterations: {}".format(iterations)) + print("Average training time: {}".format(train_time/iterations)) + print("Average update time: {}".format(update_time/iterations)) + +if __name__ == '__main__': + if len(sys.argv) != 2: + print("Pass in one argument of 'cpu', 'cuda', or 'opencl'.") + quit() + + system = sys.argv[1] + print("Running on {}.".format(system)) + + if system == 'cpu': + net = create_net(True, False) + dev = device.get_default_device() + elif system == 'cuda': + net = create_net(False, False) + dev = device.create_cuda_gpu() + elif system == 'opencl': + net = create_net(False, True) + dev = device.create_opencl_device() + + net.to_device(dev) + train(net, dev) diff --git a/examples/benchmark/benchmark_overfeat.py b/examples/benchmark/benchmark_overfeat.py new file mode 100644 index 0000000000..fa4ef964e6 --- /dev/null +++ b/examples/benchmark/benchmark_overfeat.py @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +''' This model is created following the structure from +https://github.com/soumith/convnet-benchmarks/blob/master/caffe/imagenet_winners/overfeat.prototxt +''' + +import sys +import timeit +import numpy as np + +from singa import device +from singa import layer +from singa import loss +from singa import metric +from singa import net as ffnet +from singa import tensor +from singa import optimizer +from singa.proto import core_pb2 + +iterations = 10 +batch_size = 128 +input_shape = (3, 231, 231) + +def create_net(use_cpu = False, use_ocl = False): + if use_cpu: + layer.engine = 'singacpp' + if use_ocl: + layer.engine = 'singacl' + + net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy()) + + # Conv 1 + net.add(layer.Conv2D("conv1", 96, 11, 4, input_sample_shape=input_shape)) + net.add(layer.Activation("conv1/relu" )) + net.add(layer.MaxPooling2D("pool1/2x2_s2", 2, 2, border_mode='valid')) + + # Conv 2 + net.add(layer.Conv2D("conv1/5x5_s1", 256, 5, 1)) + net.add(layer.Activation("conv2/relu")) + net.add(layer.MaxPooling2D("pool2/2x2_s2", 2, 2, border_mode='valid')) + + # Conv 3 + net.add(layer.Conv2D("conv3/3x3_s1", 512, 3, 1, pad=1)) + net.add(layer.Activation("conv3/relu")) + + # Conv 4 + net.add(layer.Conv2D("conv4/3x3_s1", 1024, 3, 1, pad=1)) + net.add(layer.Activation("conv4/relu")) + + # Conv 5 + net.add(layer.Conv2D("conv5/3x3_s1", 1024, 3, 1, pad=1)) + net.add(layer.Activation("conv5/relu")) + net.add(layer.MaxPooling2D("pool5/2x2_s2", 2, 2, border_mode='valid')) + + # L2 Norm -> Inner product + net.add(layer.Flatten("flat")) + net.add(layer.Dense("fc6", 3072)) + net.add(layer.Activation("fc6/relu6")) + + net.add(layer.Dense("fc7", 4096)) + net.add(layer.Activation("fc7/relu7")) + + net.add(layer.Dense("fc8", 1000)) + + for (val, spec) in zip(net.param_values(), net.param_specs()): + filler = spec.filler + if filler.type == 'gaussian': + val.gaussian(filler.mean, filler.std) + else: + val.set_value(0) + print spec.name, filler.type, val.l1() + + return net + +# Time forward, backward, parameter update, per layer (1x forward, 1x backward) +def train(net, dev): + tx = tensor.Tensor((batch_size,) + input_shape, dev) + ty = tensor.Tensor((batch_size,), dev) # Should be integers, but CUDA with int tensor is not supported yet + tx.gaussian(1.0, 0.5) + ty.set_value(0.0) + + opt = optimizer.SGD(momentum=0.9) + idx = np.arange(tx.shape[0], dtype = np.int32) + loss = 0.0 + acc = 0.0 + + train_time = 0.0 + update_time = 0.0 + for b in range(iterations): + + t0 = timeit.default_timer() + grads, (l, a) = net.train(tx, ty) + t1 = timeit.default_timer() + t1 -= t0 + train_time += t1 + + loss += l + acc += a + + t2 = timeit.default_timer() + for (s, p, g) in zip(net.param_names(), net.param_values(), grads): + opt.apply_with_lr(0, 0.01, g, p, str(s), b) + t3 = timeit.default_timer() + t3 -= t2 + update_time += t3 + + print("Iteration {}: Train: {}, Update: {}".format(b, t1, t3)) + + print("Total iterations: {}".format(iterations)) + print("Average training time: {}".format(train_time/iterations)) + print("Average update time: {}".format(update_time/iterations)) + +if __name__ == '__main__': + if len(sys.argv) != 2: + print("Pass in one argument of 'cpu', 'cuda', or 'opencl'.") + quit() + + system = sys.argv[1] + print("Running on {}.".format(system)) + + if system == 'cpu': + net = create_net(True, False) + dev = device.get_default_device() + elif system == 'cuda': + net = create_net(False, False) + dev = device.create_cuda_gpu() + elif system == 'opencl': + net = create_net(False, True) + dev = device.create_opencl_device() + + net.to_device(dev) + train(net, dev) diff --git a/examples/benchmark/benchmark_vgg.py b/examples/benchmark/benchmark_vgg.py new file mode 100644 index 0000000000..03b233ef73 --- /dev/null +++ b/examples/benchmark/benchmark_vgg.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +''' This model is created following the structure from +https://github.com/soumith/convnet-benchmarks/blob/master/caffe/imagenet_winners/vgg_a.prototxt +''' + +import sys +import timeit +import numpy as np + +from singa import device +from singa import layer +from singa import loss +from singa import metric +from singa import net as ffnet +from singa import tensor +from singa import optimizer +from singa.proto import core_pb2 + +iterations = 10 +batch_size = 64 +input_shape = (3, 224, 224) + +def create_net(use_cpu = False, use_ocl = False): + if use_cpu: + layer.engine = 'singacpp' + if use_ocl: + layer.engine = 'singacl' + + net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy()) + + net.add(layer.Conv2D("conv1/3x3_s1", 64, 3, 1, pad=1, input_sample_shape=input_shape)) + net.add(layer.Activation("conv1/relu")) + net.add(layer.MaxPooling2D("pool1/2x2_s2", 2, 2, border_mode='valid')) + + net.add(layer.Conv2D("conv2/3x3_s1", 128, 3, 1, pad=1)) + net.add(layer.Activation("conv2/relu")) + net.add(layer.MaxPooling2D("pool2/2x2_s2", 2, 2, border_mode='valid')) + + net.add(layer.Conv2D("conv3/3x3_s1", 256, 3, 1, pad=1)) + net.add(layer.Activation("conv3/relu")) + # No pooling layer here. + + net.add(layer.Conv2D("conv4/3x3_s1", 256, 3, 1, pad=1)) + net.add(layer.Activation("conv4/relu")) + net.add(layer.MaxPooling2D("pool3/2x2_s2", 2, 2, border_mode='valid')) + + net.add(layer.Conv2D("conv5/3x3_s1", 512, 3, 1, pad=1)) + net.add(layer.Activation("conv5/relu")) + # No pooling layer here. + + net.add(layer.Conv2D("conv6/3x3_s1", 512, 3, 1, pad=1)) + net.add(layer.Activation("conv6/relu")) + net.add(layer.MaxPooling2D("pool4/2x2_s2", 2, 2, border_mode='valid')) + + net.add(layer.Conv2D("conv7/3x3_s1", 512, 3, 1, pad=1)) + net.add(layer.Activation("conv7/relu")) + # No pooling layer here. + + net.add(layer.Conv2D("conv8/3x3_s1", 512, 3, 1, pad=1)) + net.add(layer.Activation("conv8/relu")) + net.add(layer.MaxPooling2D("pool5/2x2_s2", 2, 2, border_mode='valid')) + + net.add(layer.Flatten('flat')) + net.add(layer.Dense("fc6", 4096)) + net.add(layer.Dense("fc7", 4096)) + net.add(layer.Dense("fc8", 1000)) + + for (val, spec) in zip(net.param_values(), net.param_specs()): + filler = spec.filler + if filler.type == 'gaussian': + val.gaussian(filler.mean, filler.std) + else: + val.set_value(0) + print spec.name, filler.type, val.l1() + + return net + +def train(net, dev): + tx = tensor.Tensor((batch_size,) + input_shape, dev) + ty = tensor.Tensor((batch_size,), dev) # Should be integers, but CUDA with int tensor is not supported yet + tx.gaussian(1.0, 0.5) + ty.set_value(0.0) + + opt = optimizer.SGD(momentum=0.9) + idx = np.arange(tx.shape[0], dtype = np.int32) + loss = 0.0 + acc = 0.0 + + train_time = 0.0 + update_time = 0.0 + for b in range(iterations): + + t0 = timeit.default_timer() + grads, (l, a) = net.train(tx, ty) + t1 = timeit.default_timer() + t1 -= t0 + train_time += t1 + + loss += l + acc += a + + t2 = timeit.default_timer() + for (s, p, g) in zip(net.param_names(), net.param_values(), grads): + opt.apply_with_lr(0, 0.01, g, p, str(s), b) + t3 = timeit.default_timer() + t3 -= t2 + update_time += t3 + + print("Iteration {}: Train: {}, Update: {}".format(b, t1, t3)) + + print("Total iterations: {}".format(iterations)) + print("Average training time: {}".format(train_time/iterations)) + print("Average update time: {}".format(update_time/iterations)) + +if __name__ == '__main__': + if len(sys.argv) != 2: + print("Pass in one argument of 'cpu', 'cuda', or 'opencl'.") + quit() + + system = sys.argv[1] + print("Running on {}.".format(system)) + + if system == 'cpu': + net = create_net(True, False) + dev = device.get_default_device() + elif system == 'cuda': + net = create_net(False, False) + dev = device.create_cuda_gpu() + elif system == 'opencl': + net = create_net(False, True) + dev = device.create_opencl_device() + + net.to_device(dev) + train(net, dev) diff --git a/include/singa/core/device.h b/include/singa/core/device.h index 0fecc6d8f2..06993ab488 100644 --- a/include/singa/core/device.h +++ b/include/singa/core/device.h @@ -50,7 +50,7 @@ namespace singa { /// There are three types of devices distinguished by their programming /// languages, namely cpp, cuda and opencl. class Device { - public: +public: // Device() = default; virtual ~Device() {} /// Constructor with device ID, num of executors (e.g., cuda streams), @@ -102,10 +102,10 @@ class Device { int id() const { return id_; } - private: +private: Device() {}; - protected: +protected: /// Execute one operation on one executor. virtual void DoExec(function&& fn, int executor) = 0; @@ -118,7 +118,7 @@ class Device { /// Free device memory. virtual void Free(void* ptr) = 0; - protected: +protected: int id_ = 0; int num_executors_ = 0; unsigned seed_ = 0; @@ -140,14 +140,14 @@ extern std::shared_ptr defaultDevice; /// Represent a CPU device which may have multiple threads/executors. /// It runs cpp code. class CppCPU : public Device { - public: +public: ~CppCPU() {}; CppCPU(); std::shared_ptr host() const override { return defaultDevice;} void SetRandSeed(unsigned seed) override; - protected: +protected: void DoExec(function&& fn, int executor) override; void CopyToFrom(void* dst, const void* src, size_t nBytes, @@ -167,7 +167,7 @@ class CppCPU : public Device { #ifdef USE_CUDA // Represent a Nvidia GPU which runs cuda code. class CudaGPU : public Device { - public: +public: ~CudaGPU(); /// Construct the device using default mem pool setting. CudaGPU(int id = 0); @@ -177,7 +177,7 @@ class CudaGPU : public Device { void SetRandSeed(unsigned seed) override; size_t GetAllocatedMem() override; - protected: +protected: void DoExec(function&& fn, int executor) override; void CopyToFrom(void* dst, const void* src, size_t nBytes, @@ -189,10 +189,10 @@ class CudaGPU : public Device { /// Free cpu memory. void Free(void* ptr) override; - private: +private: void Setup(); - private: +private: shared_ptr pool_; }; @@ -292,20 +292,30 @@ class Platform { CreateCudaGPUsOn(const std::vector &devices, size_t init_size = 0); #endif // USE_CUDA +#ifdef USE_OPENCL + + const int GetNumOCLPlatforms(); + + const int GetNumOCLDevices(); + + static const std::shared_ptr GetDefaultOCLDevice(); + /// Create a \p num_devices set of valid OpenCL devices, regardless of /// platforms. If there are fewer valid devices than requested, then this - /// method will return as many as possible.If OpenCL is not in use, this + /// method will return as many as possible. If OpenCL is not in use, this /// method will return an empty array. - const std::vector > CreateOpenclDevices( - const size_t num_devices); +// static const std::vector> +// CreateOCLDevices(const size_t num_devices); /// Create a set of valid OpenCL devices, regardless of platforms, assigning /// \p id to each device in sequence. /// If there are fewer valid devices than requested, then this method will /// return as many as possible. /// If OpenCL is not in use, this method will return an empty array. - const std::vector > - CreateOpenclDevices(const vector &id); +// static const std::vector> +// CreateOCLDevices(const std::vector &id); + +#endif // USE_OPENCL /// This function is implementd by Caffe (http://caffe.berkeleyvision.org/). /// This function checks the availability of GPU #device_id. diff --git a/python/singa/device.py b/python/singa/device.py index f250f9e6af..749db4c124 100644 --- a/python/singa/device.py +++ b/python/singa/device.py @@ -120,6 +120,22 @@ def create_cuda_gpu_on(device_id): return devices[0] +def get_num_ocl_platforms(): + return singa.Platform.GetNumOCLPlatforms() + +def get_num_ocl_devices(): + return singa.Platform.GetNumOCLDevices() + + +def create_opencl_device(): + '''Create the default OpenCL device. + + Returns: + a swig converted OpenCL device. + ''' + return singa.Platform.GetDefaultOCLDevice() + + default_device = singa.Platform.GetDefaultDevice() diff --git a/python/singa/layer.py b/python/singa/layer.py index 583126a287..8a75161835 100644 --- a/python/singa/layer.py +++ b/python/singa/layer.py @@ -75,6 +75,7 @@ class Layer(object): 1. construct layer without input_sample_shapes, goto 2; construct layer with input_sample_shapes, goto 3; 2. call setup to create the parameters and setup other meta fields +w 3. call forward or access layer members 4. call backward and get parameters for update @@ -350,7 +351,7 @@ def __init__(self, name, nb_kernels, kernel=3, stride=1, border_mode='same', self.conf.param.extend([bspecs]) self.param_specs.append(bspecs) - _check_engine(engine, ['cudnn', 'singacpp']) + _check_engine(engine, ['cudnn', 'singacpp', 'singacl']) self.layer = _create_layer(engine, 'Convolution') if input_sample_shape is not None: self.setup(input_sample_shape) @@ -407,7 +408,7 @@ def __init__(self, name, mode, kernel=3, stride=2, border_mode='same', conf = self.conf.pooling_conf conf = _set_kernel_stride_pad(conf, kernel, stride, border_mode, pad) conf.pool = mode - _check_engine(engine, ['cudnn', 'singacpp']) + _check_engine(engine, ['cudnn', 'singacpp', 'singacl']) self.layer = _create_layer(engine, 'Pooling') if input_sample_shape is not None: self.setup(input_sample_shape) diff --git a/python/singa/net.py b/python/singa/net.py index 36c70f8aea..117ac00985 100644 --- a/python/singa/net.py +++ b/python/singa/net.py @@ -28,6 +28,7 @@ '''For display training information, e.g L1 value of layer data''' verbose = False +#benchmark = True class FeedForwardNet(object): diff --git a/src/api/config.i.in b/src/api/config.i.in index cea35171d1..05ddf6ed50 100644 --- a/src/api/config.i.in +++ b/src/api/config.i.in @@ -1,6 +1,7 @@ // Pass in cmake configurations to swig #cmakedefine01 USE_CUDA #cmakedefine01 USE_CUDNN +#cmakedefine01 USE_OPENCL #cmakedefine01 USE_PYTHON #cmakedefine01 USE_JAVA #cmakedefine CUDNN_VERSION ${CUDNN_VERSION} diff --git a/src/api/core_device.i b/src/api/core_device.i index a9bb840cb3..04d028afe3 100644 --- a/src/api/core_device.i +++ b/src/api/core_device.i @@ -44,25 +44,38 @@ namespace std{ namespace singa{ class Device { - public: +public: virtual void SetRandSeed(unsigned seed) = 0; std::shared_ptr host(); int id() const; }; class Platform { - public: +public: #if USE_CUDA static int GetNumGPUs(); static const std::vector GetGPUIDs(); static const std::pair GetGPUMemSize(const int device); static const std::vector> GetGPUMemSize(); static const std::string DeviceQuery(int id, bool verbose = false); - static const std::vector > + static const std::vector> CreateCudaGPUs(const size_t num_devices, size_t init_size = 0); static const std::vector> CreateCudaGPUsOn(const std::vector &devices, size_t init_size = 0); #endif // USE_CUDA + +#if USE_OPENCL + + const int GetNumOCLPlatforms(); + const int GetNumOCLDevices(); + static const std::shared_ptr GetDefaultOCLDevice(); +// static const std::vector> +// CreateOpenclDevices(const size_t num_devices); +// static const std::vector> +// CreateOpenclDevices(); + +#endif // USE_OPENCL + static std::shared_ptr GetDefaultDevice(); }; diff --git a/src/core/device/platform.cc b/src/core/device/platform.cc index eb02c5bb1e..6de0b396ba 100644 --- a/src/core/device/platform.cc +++ b/src/core/device/platform.cc @@ -19,11 +19,12 @@ #include "singa/core/device.h" #include "singa/singa_config.h" - -#ifdef USE_CUDA +#include "singa/utils/opencl_utils.h" namespace singa { +#ifdef USE_CUDA + int Platform::GetNumGPUs() { int count; CUDA_CHECK(cudaGetDeviceCount(&count)); @@ -109,7 +110,7 @@ const string Platform::DeviceQuery(int device, bool verbose) { return out.str(); } -const vector > +const vector> Platform::CreateCudaGPUs(const size_t num_devices, size_t init_size) { const vector gpus = GetGPUIDs(); CHECK_LE(num_devices, gpus.size()); @@ -117,7 +118,7 @@ Platform::CreateCudaGPUs(const size_t num_devices, size_t init_size) { return CreateCudaGPUsOn(use_gpus, init_size); } -const vector > +const vector> Platform::CreateCudaGPUsOn(const vector &devices, size_t init_size) { MemPoolConf conf; if (init_size > 0) @@ -137,8 +138,46 @@ Platform::CreateCudaGPUsOn(const vector &devices, size_t init_size) { return ret; } -} // namespace singa - #endif // USE_CUDA -#endif \ No newline at end of file +#ifdef USE_OPENCL + +const int Platform::GetNumOCLPlatforms() { + auto all_platforms = viennacl::ocl::get_platforms(); + return (int)all_platforms.size(); +} + +const int Platform::GetNumOCLDevices() { + auto all_platforms = viennacl::ocl::get_platforms(); + unsigned int total_num_devices = 0; + for (auto plat : all_platforms) { + auto all_devices = plat.devices(CL_DEVICE_TYPE_ALL); + total_num_devices += all_devices.size(); + } + return (int)total_num_devices; +} + +const std::shared_ptr Platform::GetDefaultOCLDevice() { + return std::make_shared(); +} +/* +static const std::vector> +Platform::CreateOCLDevices(const size_t num_devices) { + auto all_platforms = viennacl::ocl::get_platforms(); + for (auto plat : all_platforms) { + auto all_devices = plat.devices(CL_DEVICE_TYPE_ALL); + total_num_devices += all_devices.size(); + } + return (int)total_num_devices; +} + +static const std::vector> +Platform::CreateOCLDevices(const std::vector &id) { + +} +*/ +#endif // USE_OPENCL + +} // namespace singa + +#endif diff --git a/src/core/tensor/tensor_math_opencl.h b/src/core/tensor/tensor_math_opencl.h index a29c7063dc..c939dbbe07 100644 --- a/src/core/tensor/tensor_math_opencl.h +++ b/src/core/tensor/tensor_math_opencl.h @@ -440,36 +440,17 @@ void Amin(const size_t num, const Block* in, size_t* out, C out[0] = temp[0]; delete temp; } - +*/ template<> void Asum(const size_t num, const Block* in, float* out, Context* ctx) { - cl_int status = CL_SUCCESS; - - std::string kname = "clkernel_asum"; - auto kernel = ctx->kernels->at(kname); - - cl::Buffer inbuf = *(static_cast(in->mutable_data())); - - size_t size = sizeof(float) * num; - cl::Buffer outval(ctx->ocl_ctx, CL_MEM_WRITE_ONLY, size, nullptr, &status); - OCL_CHECK(status, "Failed to create buffer!"); - - kernel.setArg(0, (cl_int)num); - kernel.setArg(1, inbuf); - kernel.setArg(2, outval); - kernel.setArg(3, cl::Local(size)); + viennacl::vector v_in((const cl_mem)in->data(), num); - status = ctx->ocl_cmdq.enqueueNDRangeKernel(kernel, cl::NDRange(0), cl::NDRange(num)); - OCL_CHECK(status, "Failed to enqueue kernel function!"); + viennacl::vector temp = viennacl::linalg::element_fabs(v_in); - float* temp = new float[num]; - status = ctx->ocl_cmdq.enqueueReadBuffer(outval, CL_TRUE, 0, size, temp); - OCL_CHECK(status, "Failed to read from buffer!"); - out[0] = temp[0]; - delete temp; + out[0] = viennacl::linalg::sum(temp); } -*/ + /// out = alpha * in + out template<> void Axpy(const size_t num, const float alpha, const Block* in, Block* out, Context* ctx) { @@ -528,7 +509,7 @@ void GEMV(bool trans, const size_t m, const size_t n, const } /// multiply a matrix with a diagonal matrix constructed using values from 'v'. -/// if matrix_lef_side is true, do M*v; else do v*M +/// if matrix_left_side is true, do M*v; else do v*M template<> void DGMM(bool side_right, const size_t nrow, const size_t ncol, @@ -541,9 +522,9 @@ void DGMM(bool side_right, auto diag = viennacl::diag(v_buf); if (side_right) { - out_buf = viennacl::linalg::prod(diag, M_buf); - } else { out_buf = viennacl::linalg::prod(M_buf, diag); + } else { + out_buf = viennacl::linalg::prod(diag, M_buf); } } diff --git a/src/model/layer/opencl_convolution.cc b/src/model/layer/opencl_convolution.cc index c43719ff7a..4b70a714a7 100644 --- a/src/model/layer/opencl_convolution.cc +++ b/src/model/layer/opencl_convolution.cc @@ -22,7 +22,7 @@ namespace singa { -RegisterLayerClass(opencl_convolution, OpenclConvolution); +RegisterLayerClass(singacl_convolution, OpenclConvolution); /// \copydoc Layer::Forward(int flag, const Tensor&) const Tensor OpenclConvolution::Forward(int flag, const Tensor &input) { diff --git a/src/model/layer/opencl_pooling.cc b/src/model/layer/opencl_pooling.cc index 2e3533078f..f123270b9c 100644 --- a/src/model/layer/opencl_pooling.cc +++ b/src/model/layer/opencl_pooling.cc @@ -22,7 +22,7 @@ namespace singa { -RegisterLayerClass(opencl_pooling, OpenclPooling); +RegisterLayerClass(singacl_pooling, OpenclPooling); const Tensor OpenclPooling::Forward(int flag, const Tensor &input) { CHECK(buf_.empty()); From 5857698c06a36478d66c66e90eba26f55fa59834 Mon Sep 17 00:00:00 2001 From: Tan Li Boon Date: Tue, 1 Nov 2016 15:33:28 +0800 Subject: [PATCH 4/6] Added train with benchmarking to python/singa/net.py. --- examples/benchmark/benchmark_alexnet.py | 20 +++++++++---------- examples/benchmark/benchmark_overfeat.py | 20 +++++++++---------- examples/benchmark/benchmark_vgg.py | 25 +++++++++++------------- python/singa/net.py | 20 +++++++++++++++++-- 4 files changed, 47 insertions(+), 38 deletions(-) diff --git a/examples/benchmark/benchmark_alexnet.py b/examples/benchmark/benchmark_alexnet.py index 89c76bdede..6e1ffa71fe 100644 --- a/examples/benchmark/benchmark_alexnet.py +++ b/examples/benchmark/benchmark_alexnet.py @@ -103,26 +103,24 @@ def train(net, dev): for b in range(iterations): t0 = timeit.default_timer() - grads, (l, a) = net.train(tx, ty) - t1 = timeit.default_timer() - t1 -= t0 - train_time += t1 + grads, (l, a) = net.train_benchmark(tx, ty) + t0 = timeit.default_timer() - t0 + train_time += t0 loss += l acc += a - t2 = timeit.default_timer() + t1 = timeit.default_timer() for (s, p, g) in zip(net.param_names(), net.param_values(), grads): opt.apply_with_lr(0, 0.01, g, p, str(s), b) - t3 = timeit.default_timer() - t3 -= t2 - update_time += t3 + t1 = timeit.default_timer() - t1 + update_time += t1 - print("Iteration {}: Train: {}, Update: {}".format(b, t1, t3)) + print("Iteration {}: Train: {}, Update: {}".format(b, round(t0, 4), round(t1, 4))) print("Total iterations: {}".format(iterations)) - print("Average training time: {}".format(train_time/iterations)) - print("Average update time: {}".format(update_time/iterations)) + print("Average training time: {0:.4f}".format(train_time/iterations)) + print("Average update time: {0:.4f}".format(update_time/iterations)) if __name__ == '__main__': if len(sys.argv) != 2: diff --git a/examples/benchmark/benchmark_overfeat.py b/examples/benchmark/benchmark_overfeat.py index fa4ef964e6..b20ecdc58d 100644 --- a/examples/benchmark/benchmark_overfeat.py +++ b/examples/benchmark/benchmark_overfeat.py @@ -103,26 +103,24 @@ def train(net, dev): for b in range(iterations): t0 = timeit.default_timer() - grads, (l, a) = net.train(tx, ty) - t1 = timeit.default_timer() - t1 -= t0 - train_time += t1 + grads, (l, a) = net.train_benchmark(tx, ty) + t0 = timeit.default_timer() - t0 + train_time += t0 loss += l acc += a - t2 = timeit.default_timer() + t1 = timeit.default_timer() for (s, p, g) in zip(net.param_names(), net.param_values(), grads): opt.apply_with_lr(0, 0.01, g, p, str(s), b) - t3 = timeit.default_timer() - t3 -= t2 - update_time += t3 + t1 = timeit.default_timer() - t1 + update_time += t1 - print("Iteration {}: Train: {}, Update: {}".format(b, t1, t3)) + print("Iteration {}: Train: {}, Update: {}".format(b, round(t0, 4), round(t1, 4))) print("Total iterations: {}".format(iterations)) - print("Average training time: {}".format(train_time/iterations)) - print("Average update time: {}".format(update_time/iterations)) + print("Average training time: {0:.4f}".format(train_time/iterations)) + print("Average update time: {0:.4f}".format(update_time/iterations)) if __name__ == '__main__': if len(sys.argv) != 2: diff --git a/examples/benchmark/benchmark_vgg.py b/examples/benchmark/benchmark_vgg.py index 03b233ef73..507e62464a 100644 --- a/examples/benchmark/benchmark_vgg.py +++ b/examples/benchmark/benchmark_vgg.py @@ -102,30 +102,27 @@ def train(net, dev): acc = 0.0 train_time = 0.0 - update_time = 0.0 for b in range(iterations): - + t0 = timeit.default_timer() - grads, (l, a) = net.train(tx, ty) - t1 = timeit.default_timer() - t1 -= t0 - train_time += t1 + grads, (l, a) = net.train_benchmark(tx, ty) + t0 = timeit.default_timer() - t0 + train_time += t0 loss += l acc += a - t2 = timeit.default_timer() + t1 = timeit.default_timer() for (s, p, g) in zip(net.param_names(), net.param_values(), grads): opt.apply_with_lr(0, 0.01, g, p, str(s), b) - t3 = timeit.default_timer() - t3 -= t2 - update_time += t3 - - print("Iteration {}: Train: {}, Update: {}".format(b, t1, t3)) + t1 = timeit.default_timer() - t1 + update_time += t1 + + print("Iteration {}: Train: {}, Update: {}".format(b, round(t0, 4), round(t1, 4))) print("Total iterations: {}".format(iterations)) - print("Average training time: {}".format(train_time/iterations)) - print("Average update time: {}".format(update_time/iterations)) + print("Average training time: {0:.4f}".format(train_time/iterations)) + print("Average update time: {0:.4f}".format(update_time/iterations)) if __name__ == '__main__': if len(sys.argv) != 2: diff --git a/python/singa/net.py b/python/singa/net.py index 117ac00985..7824c1cc83 100644 --- a/python/singa/net.py +++ b/python/singa/net.py @@ -19,7 +19,7 @@ functions for net info, e.g., parameters. """ - +import timeit from .proto.model_pb2 import kTrain, kEval import tensor import layer @@ -28,7 +28,7 @@ '''For display training information, e.g L1 value of layer data''' verbose = False -#benchmark = True +benchmark = True class FeedForwardNet(object): @@ -134,6 +134,22 @@ def train(self, x, y): if self.metric is not None: m = self.metric.evaluate(out, y) return self.backward(), (l.l1(), m) + + def train_benchmark(self, x, y): + t0 = timeit.default_timer() + out = self.forward(kTrain, x) + t0 = timeit.default_timer() - t0 + + l = self.loss.forward(kTrain, out, y) + if self.metric is not None: + m = self.metric.evaluate(out, y) + + t1 = timeit.default_timer() + grads = self.backward() + t1 = timeit.default_timer() - t1 + + print("Forward: {0:.4f}\tBackward: {0:.4f}".format(t0, t1)) + return grads, (l.l1(), m) def evaluate(self, x, y): '''Evaluate the loss and metric of the given data. From 8d92d02a3a47e7cd31f3b26acd056bd856f53b42 Mon Sep 17 00:00:00 2001 From: Wei Wang Date: Thu, 17 Nov 2016 23:34:16 +0800 Subject: [PATCH 5/6] 1. move net definition into separate files, and merge the benchmark of three model to share the same training code, i.e. run.sh 2. add arrays in net.py to record the time information 3. rename ocl into opencl to be consistent with other functions/variables --- .../{benchmark_alexnet.py => alexnet.py} | 89 +++-------------- .../{benchmark_overfeat.py => overfeat.py} | 86 ++-------------- examples/benchmark/run.py | 97 +++++++++++++++++++ .../benchmark/{benchmark_vgg.py => vgg.py} | 93 +++--------------- include/singa/core/device.h | 69 +++++++------ python/singa/device.py | 13 +-- python/singa/net.py | 92 +++++++++++++----- src/api/core_device.i | 10 +- src/core/device/platform.cc | 12 +-- 9 files changed, 251 insertions(+), 310 deletions(-) rename examples/benchmark/{benchmark_alexnet.py => alexnet.py} (53%) rename examples/benchmark/{benchmark_overfeat.py => overfeat.py} (55%) create mode 100644 examples/benchmark/run.py rename examples/benchmark/{benchmark_vgg.py => vgg.py} (57%) diff --git a/examples/benchmark/benchmark_alexnet.py b/examples/benchmark/alexnet.py similarity index 53% rename from examples/benchmark/benchmark_alexnet.py rename to examples/benchmark/alexnet.py index 6e1ffa71fe..48c812ab53 100644 --- a/examples/benchmark/benchmark_alexnet.py +++ b/examples/benchmark/alexnet.py @@ -18,62 +18,52 @@ https://github.com/soumith/convnet-benchmarks/blob/master/caffe/imagenet_winners/alexnet.prototxt ''' -import sys -import timeit -import numpy as np - -from singa import device from singa import layer from singa import loss from singa import metric from singa import net as ffnet -from singa import tensor -from singa import optimizer -from singa.proto import core_pb2 -iterations = 10 -batch_size = 128 -input_shape = (3, 224, 224) -def create_net(use_cpu = False, use_ocl = False): +def create_net(input_shape, use_cpu=False, use_ocl=False): if use_cpu: layer.engine = 'singacpp' if use_ocl: layer.engine = 'singacl' net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy()) - + # Conv 1 - net.add(layer.Conv2D("conv1", 64, 11, 4, pad=2, input_sample_shape=input_shape)) + net.add(layer.Conv2D("conv1", 64, 11, 4, pad=2, + input_sample_shape=input_shape)) net.add(layer.Activation("conv1/relu" )) net.add(layer.MaxPooling2D("pool1/3x3_s2", 3, 2)) - + # Conv 2 net.add(layer.Conv2D("conv1/5x5_s1", 192, 5, 1, pad=2)) net.add(layer.Activation("conv2/relu")) net.add(layer.MaxPooling2D("pool2/3x3_s2", 3, 2)) - + # Conv 3 net.add(layer.Conv2D("conv3/3x3_s1", 384, 3, 1, pad=1)) net.add(layer.Activation("conv3/relu")) - + # Conv 4 net.add(layer.Conv2D("conv4/3x3_s1", 256, 3, 1, pad=1)) net.add(layer.Activation("conv4/relu")) - + # Conv 5 net.add(layer.Conv2D("conv5/3x3_s1", 256, 3, 1, pad=1)) net.add(layer.Activation("conv5/relu")) net.add(layer.MaxPooling2D("pool5/3x3_s2", 3, 2)) - + # L2 Norm -> Inner product net.add(layer.Flatten("flat")) net.add(layer.Dense("fc6", 4096)) net.add(layer.Activation("fc6/relu6")) - + net.add(layer.Dense("fc7", 4096)) net.add(layer.Activation("fc7/relu7")) - + net.add(layer.Dense("fc8", 1000)) for (val, spec) in zip(net.param_values(), net.param_specs()): @@ -85,60 +75,3 @@ def create_net(use_cpu = False, use_ocl = False): print spec.name, filler.type, val.l1() return net - -# Time forward, backward, parameter update, per layer (1x forward, 1x backward) -def train(net, dev): - tx = tensor.Tensor((batch_size,) + input_shape, dev) - ty = tensor.Tensor((batch_size,), dev) # Should be integers, but CUDA with int tensor is not supported yet - tx.gaussian(1.0, 0.5) - ty.set_value(0.0) - - opt = optimizer.SGD(momentum=0.9) - idx = np.arange(tx.shape[0], dtype = np.int32) - loss = 0.0 - acc = 0.0 - - train_time = 0.0 - update_time = 0.0 - for b in range(iterations): - - t0 = timeit.default_timer() - grads, (l, a) = net.train_benchmark(tx, ty) - t0 = timeit.default_timer() - t0 - train_time += t0 - - loss += l - acc += a - - t1 = timeit.default_timer() - for (s, p, g) in zip(net.param_names(), net.param_values(), grads): - opt.apply_with_lr(0, 0.01, g, p, str(s), b) - t1 = timeit.default_timer() - t1 - update_time += t1 - - print("Iteration {}: Train: {}, Update: {}".format(b, round(t0, 4), round(t1, 4))) - - print("Total iterations: {}".format(iterations)) - print("Average training time: {0:.4f}".format(train_time/iterations)) - print("Average update time: {0:.4f}".format(update_time/iterations)) - -if __name__ == '__main__': - if len(sys.argv) != 2: - print("Pass in one argument of 'cpu', 'cuda', or 'opencl'.") - quit() - - system = sys.argv[1] - print("Running on {}.".format(system)) - - if system == 'cpu': - net = create_net(True, False) - dev = device.get_default_device() - elif system == 'cuda': - net = create_net(False, False) - dev = device.create_cuda_gpu() - elif system == 'opencl': - net = create_net(False, True) - dev = device.create_opencl_device() - - net.to_device(dev) - train(net, dev) diff --git a/examples/benchmark/benchmark_overfeat.py b/examples/benchmark/overfeat.py similarity index 55% rename from examples/benchmark/benchmark_overfeat.py rename to examples/benchmark/overfeat.py index b20ecdc58d..ca7a99d8fc 100644 --- a/examples/benchmark/benchmark_overfeat.py +++ b/examples/benchmark/overfeat.py @@ -18,62 +18,51 @@ https://github.com/soumith/convnet-benchmarks/blob/master/caffe/imagenet_winners/overfeat.prototxt ''' -import sys -import timeit -import numpy as np - -from singa import device from singa import layer from singa import loss from singa import metric from singa import net as ffnet -from singa import tensor -from singa import optimizer -from singa.proto import core_pb2 -iterations = 10 -batch_size = 128 -input_shape = (3, 231, 231) -def create_net(use_cpu = False, use_ocl = False): +def create_net(input_shape, use_cpu=False, use_ocl=False): if use_cpu: layer.engine = 'singacpp' if use_ocl: layer.engine = 'singacl' net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy()) - + # Conv 1 net.add(layer.Conv2D("conv1", 96, 11, 4, input_sample_shape=input_shape)) net.add(layer.Activation("conv1/relu" )) net.add(layer.MaxPooling2D("pool1/2x2_s2", 2, 2, border_mode='valid')) - + # Conv 2 net.add(layer.Conv2D("conv1/5x5_s1", 256, 5, 1)) net.add(layer.Activation("conv2/relu")) net.add(layer.MaxPooling2D("pool2/2x2_s2", 2, 2, border_mode='valid')) - + # Conv 3 net.add(layer.Conv2D("conv3/3x3_s1", 512, 3, 1, pad=1)) net.add(layer.Activation("conv3/relu")) - + # Conv 4 net.add(layer.Conv2D("conv4/3x3_s1", 1024, 3, 1, pad=1)) net.add(layer.Activation("conv4/relu")) - + # Conv 5 net.add(layer.Conv2D("conv5/3x3_s1", 1024, 3, 1, pad=1)) net.add(layer.Activation("conv5/relu")) net.add(layer.MaxPooling2D("pool5/2x2_s2", 2, 2, border_mode='valid')) - + # L2 Norm -> Inner product net.add(layer.Flatten("flat")) net.add(layer.Dense("fc6", 3072)) net.add(layer.Activation("fc6/relu6")) - + net.add(layer.Dense("fc7", 4096)) net.add(layer.Activation("fc7/relu7")) - + net.add(layer.Dense("fc8", 1000)) for (val, spec) in zip(net.param_values(), net.param_specs()): @@ -85,60 +74,3 @@ def create_net(use_cpu = False, use_ocl = False): print spec.name, filler.type, val.l1() return net - -# Time forward, backward, parameter update, per layer (1x forward, 1x backward) -def train(net, dev): - tx = tensor.Tensor((batch_size,) + input_shape, dev) - ty = tensor.Tensor((batch_size,), dev) # Should be integers, but CUDA with int tensor is not supported yet - tx.gaussian(1.0, 0.5) - ty.set_value(0.0) - - opt = optimizer.SGD(momentum=0.9) - idx = np.arange(tx.shape[0], dtype = np.int32) - loss = 0.0 - acc = 0.0 - - train_time = 0.0 - update_time = 0.0 - for b in range(iterations): - - t0 = timeit.default_timer() - grads, (l, a) = net.train_benchmark(tx, ty) - t0 = timeit.default_timer() - t0 - train_time += t0 - - loss += l - acc += a - - t1 = timeit.default_timer() - for (s, p, g) in zip(net.param_names(), net.param_values(), grads): - opt.apply_with_lr(0, 0.01, g, p, str(s), b) - t1 = timeit.default_timer() - t1 - update_time += t1 - - print("Iteration {}: Train: {}, Update: {}".format(b, round(t0, 4), round(t1, 4))) - - print("Total iterations: {}".format(iterations)) - print("Average training time: {0:.4f}".format(train_time/iterations)) - print("Average update time: {0:.4f}".format(update_time/iterations)) - -if __name__ == '__main__': - if len(sys.argv) != 2: - print("Pass in one argument of 'cpu', 'cuda', or 'opencl'.") - quit() - - system = sys.argv[1] - print("Running on {}.".format(system)) - - if system == 'cpu': - net = create_net(True, False) - dev = device.get_default_device() - elif system == 'cuda': - net = create_net(False, False) - dev = device.create_cuda_gpu() - elif system == 'opencl': - net = create_net(False, True) - dev = device.create_opencl_device() - - net.to_device(dev) - train(net, dev) diff --git a/examples/benchmark/run.py b/examples/benchmark/run.py new file mode 100644 index 0000000000..475ba556c4 --- /dev/null +++ b/examples/benchmark/run.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +''' This model is created following the structure from +https://github.com/soumith/convnet-benchmarks/blob/master/caffe/imagenet_winners/alexnet.prototxt +''' + +import sys +import numpy as np + +from singa import device +from singa import tensor +from singa import optimizer + +iterations = 10 +batch_size = 128 +input_shape = (3, 224, 224) + +# Time forward, backward, parameter update, per layer (1x forward, 1x backward) +def train(net, dev): + tx = tensor.Tensor((batch_size,) + input_shape, dev) + ty = tensor.Tensor((batch_size,), dev) + tx.gaussian(1.0, 0.5) + ty.set_value(0.0) + + opt = optimizer.SGD(momentum=0.9) + idx = np.arange(tx.shape[0], dtype = np.int32) + loss = 0.0 + acc = 0.0 + + train_time = 0.0 + update_time = 0.0 + net.start_benchmark() + update = 0 + for b in range(iterations): + grads, (l, a) = net.train_benchmark(tx, ty) + t1 = timer() + for (s, p, g) in zip(net.param_names(), net.param_values(), grads): + opt.apply_with_lr(0, 0.01, g, p, str(s), b) + update += timer() - t1 + t, fp, bp, fps, bps = net.stop_benchmark(iterations) + + print "Total iterations = %d" % iterations + print "Average training time per iteration = %.4f" % t + print "Average forward time per iteration = %.4f" % fp + print "Average backward time per iteration = %.4f" % bp + print "Average udpate time per iteration = %.4f" % (update / iterations) + for (k, v) in fps: + print "Forward time for %10s = %.4f" % (k, v) + for (k, v) in bps: + print "Backward time for %10s = %.4f" % (k, v) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Benchmark SINGA by running' + 'AlexNet/VGG/Overfeat with CPP/CUDA/Opencl') + parser.add_argument('net', choices=['vgg', 'alexnet', 'overfeat'], + default='alexnet') + parser.add_argument('device', choices=['cpp', 'cuda', 'opencl'], + default='cuda') + args = parser.parse_args() + if args.net == 'vgg': + import vgg as model + elif args.net == 'alexnet': + import alexnet as model + else: + assert args.net == 'overfeat', 'Wrong net type:' + args.net + import overfeat as model + + use_cpu = False, + use_opencl = False + + if args.device == 'cpu': + use_cpu = True + dev = device.get_default_device() + elif args.device == 'cuda': + dev = device.create_cuda_gpu() + else: + assert args.device == 'opencl', 'Wrong lang: ' + args.device + use_opencl = True + dev = device.create_opencl_device() + net = model.create_net(input_shape, use_cpu, use_opencl) + net.to_device(dev) + train(net, dev) diff --git a/examples/benchmark/benchmark_vgg.py b/examples/benchmark/vgg.py similarity index 57% rename from examples/benchmark/benchmark_vgg.py rename to examples/benchmark/vgg.py index 507e62464a..0dcdd93493 100644 --- a/examples/benchmark/benchmark_vgg.py +++ b/examples/benchmark/vgg.py @@ -18,68 +18,58 @@ https://github.com/soumith/convnet-benchmarks/blob/master/caffe/imagenet_winners/vgg_a.prototxt ''' -import sys -import timeit -import numpy as np - -from singa import device from singa import layer from singa import loss from singa import metric from singa import net as ffnet -from singa import tensor -from singa import optimizer -from singa.proto import core_pb2 -iterations = 10 -batch_size = 64 -input_shape = (3, 224, 224) -def create_net(use_cpu = False, use_ocl = False): +def create_net(input_shape, use_cpu=False, use_ocl=False): if use_cpu: layer.engine = 'singacpp' if use_ocl: layer.engine = 'singacl' - + net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy()) - - net.add(layer.Conv2D("conv1/3x3_s1", 64, 3, 1, pad=1, input_sample_shape=input_shape)) + + net.add(layer.Conv2D("conv1/3x3_s1", 64, 3, 1, pad=1, + input_sample_shape=input_shape)) net.add(layer.Activation("conv1/relu")) net.add(layer.MaxPooling2D("pool1/2x2_s2", 2, 2, border_mode='valid')) - + net.add(layer.Conv2D("conv2/3x3_s1", 128, 3, 1, pad=1)) net.add(layer.Activation("conv2/relu")) net.add(layer.MaxPooling2D("pool2/2x2_s2", 2, 2, border_mode='valid')) - + net.add(layer.Conv2D("conv3/3x3_s1", 256, 3, 1, pad=1)) net.add(layer.Activation("conv3/relu")) # No pooling layer here. - + net.add(layer.Conv2D("conv4/3x3_s1", 256, 3, 1, pad=1)) net.add(layer.Activation("conv4/relu")) net.add(layer.MaxPooling2D("pool3/2x2_s2", 2, 2, border_mode='valid')) - + net.add(layer.Conv2D("conv5/3x3_s1", 512, 3, 1, pad=1)) net.add(layer.Activation("conv5/relu")) # No pooling layer here. - + net.add(layer.Conv2D("conv6/3x3_s1", 512, 3, 1, pad=1)) net.add(layer.Activation("conv6/relu")) net.add(layer.MaxPooling2D("pool4/2x2_s2", 2, 2, border_mode='valid')) - + net.add(layer.Conv2D("conv7/3x3_s1", 512, 3, 1, pad=1)) net.add(layer.Activation("conv7/relu")) # No pooling layer here. - + net.add(layer.Conv2D("conv8/3x3_s1", 512, 3, 1, pad=1)) net.add(layer.Activation("conv8/relu")) net.add(layer.MaxPooling2D("pool5/2x2_s2", 2, 2, border_mode='valid')) - + net.add(layer.Flatten('flat')) net.add(layer.Dense("fc6", 4096)) net.add(layer.Dense("fc7", 4096)) net.add(layer.Dense("fc8", 1000)) - + for (val, spec) in zip(net.param_values(), net.param_specs()): filler = spec.filler if filler.type == 'gaussian': @@ -89,58 +79,3 @@ def create_net(use_cpu = False, use_ocl = False): print spec.name, filler.type, val.l1() return net - -def train(net, dev): - tx = tensor.Tensor((batch_size,) + input_shape, dev) - ty = tensor.Tensor((batch_size,), dev) # Should be integers, but CUDA with int tensor is not supported yet - tx.gaussian(1.0, 0.5) - ty.set_value(0.0) - - opt = optimizer.SGD(momentum=0.9) - idx = np.arange(tx.shape[0], dtype = np.int32) - loss = 0.0 - acc = 0.0 - - train_time = 0.0 - for b in range(iterations): - - t0 = timeit.default_timer() - grads, (l, a) = net.train_benchmark(tx, ty) - t0 = timeit.default_timer() - t0 - train_time += t0 - - loss += l - acc += a - - t1 = timeit.default_timer() - for (s, p, g) in zip(net.param_names(), net.param_values(), grads): - opt.apply_with_lr(0, 0.01, g, p, str(s), b) - t1 = timeit.default_timer() - t1 - update_time += t1 - - print("Iteration {}: Train: {}, Update: {}".format(b, round(t0, 4), round(t1, 4))) - - print("Total iterations: {}".format(iterations)) - print("Average training time: {0:.4f}".format(train_time/iterations)) - print("Average update time: {0:.4f}".format(update_time/iterations)) - -if __name__ == '__main__': - if len(sys.argv) != 2: - print("Pass in one argument of 'cpu', 'cuda', or 'opencl'.") - quit() - - system = sys.argv[1] - print("Running on {}.".format(system)) - - if system == 'cpu': - net = create_net(True, False) - dev = device.get_default_device() - elif system == 'cuda': - net = create_net(False, False) - dev = device.create_cuda_gpu() - elif system == 'opencl': - net = create_net(False, True) - dev = device.create_opencl_device() - - net.to_device(dev) - train(net, dev) diff --git a/include/singa/core/device.h b/include/singa/core/device.h index 06993ab488..6bb8193d29 100644 --- a/include/singa/core/device.h +++ b/include/singa/core/device.h @@ -50,7 +50,7 @@ namespace singa { /// There are three types of devices distinguished by their programming /// languages, namely cpp, cuda and opencl. class Device { -public: + public: // Device() = default; virtual ~Device() {} /// Constructor with device ID, num of executors (e.g., cuda streams), @@ -118,7 +118,7 @@ class Device { /// Free device memory. virtual void Free(void* ptr) = 0; -protected: + protected: int id_ = 0; int num_executors_ = 0; unsigned seed_ = 0; @@ -140,14 +140,14 @@ extern std::shared_ptr defaultDevice; /// Represent a CPU device which may have multiple threads/executors. /// It runs cpp code. class CppCPU : public Device { -public: + public: ~CppCPU() {}; CppCPU(); std::shared_ptr host() const override { return defaultDevice;} void SetRandSeed(unsigned seed) override; -protected: + protected: void DoExec(function&& fn, int executor) override; void CopyToFrom(void* dst, const void* src, size_t nBytes, @@ -167,7 +167,7 @@ class CppCPU : public Device { #ifdef USE_CUDA // Represent a Nvidia GPU which runs cuda code. class CudaGPU : public Device { -public: + public: ~CudaGPU(); /// Construct the device using default mem pool setting. CudaGPU(int id = 0); @@ -177,7 +177,7 @@ class CudaGPU : public Device { void SetRandSeed(unsigned seed) override; size_t GetAllocatedMem() override; -protected: + protected: void DoExec(function&& fn, int executor) override; void CopyToFrom(void* dst, const void* src, size_t nBytes, @@ -189,10 +189,10 @@ class CudaGPU : public Device { /// Free cpu memory. void Free(void* ptr) override; -private: + private: void Setup(); -private: + private: shared_ptr pool_; }; @@ -203,7 +203,7 @@ class CudaGPU : public Device { #ifdef USE_OPENCL // Implement Device using OpenCL libs. class OpenclDevice : public singa::Device { -public: + public: // TODO: Constructor arguments to consider: // Path to kernel sources? @@ -218,7 +218,7 @@ class OpenclDevice : public singa::Device { CopyDirection direction, int dst_offset = 0, int src_offset = 0) override; -protected: + protected: /// The OpenCL device that this object represents. /// Each OpenclDevice contains exactly one cl::Device for the lifetime of the /// object. @@ -248,7 +248,7 @@ class OpenclDevice : public singa::Device { /// This has the effect of freeing up device memory. void Free(void* ptr) override; -private: + private: static const std::string cl_src_path; }; @@ -260,7 +260,7 @@ class OpenclDevice : public singa::Device { /// return something that indicates their absence (for example, 0 devices); /// however they should always be available regardless of compile-time switches. class Platform { -public: + public: /// Return the defualt host device static std::shared_ptr GetDefaultDevice() { @@ -290,22 +290,39 @@ class Platform { /// Create a set of CudaGPU Device using given GPU IDs. static const std::vector> CreateCudaGPUsOn(const std::vector &devices, size_t init_size = 0); + + /// This function is implementd by Caffe (http://caffe.berkeleyvision.org/). + /// This function checks the availability of GPU #device_id. + /// It attempts to create a context on the device by calling cudaFree(0). + /// cudaSetDevice() alone is not sufficient to check the availability. + /// It lazily records device_id, however, does not initialize a + /// context. So it does not know if the host thread has the permission to use + /// the device or not. + /// + /// In a shared environment where the devices are set to EXCLUSIVE_PROCESS + /// or EXCLUSIVE_THREAD mode, cudaSetDevice() returns cudaSuccess + /// even if the device is exclusively occupied by another process or thread. + /// Cuda operations that initialize the context are needed to check + /// the permission. cudaFree(0) is one of those with no side effect, + /// except the context initialization. + static bool CheckDevice(const int device_id); #endif // USE_CUDA + #ifdef USE_OPENCL - const int GetNumOCLPlatforms(); + const int GetNumOpenclPlatforms(); - const int GetNumOCLDevices(); + const int GetNumOpenclDevices(); - static const std::shared_ptr GetDefaultOCLDevice(); + static const std::shared_ptr GetDefaultOpenclDevice(); /// Create a \p num_devices set of valid OpenCL devices, regardless of /// platforms. If there are fewer valid devices than requested, then this /// method will return as many as possible. If OpenCL is not in use, this /// method will return an empty array. // static const std::vector> -// CreateOCLDevices(const size_t num_devices); +// CreateOpenclDevices(const size_t num_devices); /// Create a set of valid OpenCL devices, regardless of platforms, assigning /// \p id to each device in sequence. @@ -313,29 +330,11 @@ class Platform { /// return as many as possible. /// If OpenCL is not in use, this method will return an empty array. // static const std::vector> -// CreateOCLDevices(const std::vector &id); - +// CreateOpenclDevices(const std::vector &id); #endif // USE_OPENCL - /// This function is implementd by Caffe (http://caffe.berkeleyvision.org/). - /// This function checks the availability of GPU #device_id. - /// It attempts to create a context on the device by calling cudaFree(0). - /// cudaSetDevice() alone is not sufficient to check the availability. - /// It lazily records device_id, however, does not initialize a - /// context. So it does not know if the host thread has the permission to use - /// the device or not. - /// - /// In a shared environment where the devices are set to EXCLUSIVE_PROCESS - /// or EXCLUSIVE_THREAD mode, cudaSetDevice() returns cudaSuccess - /// even if the device is exclusively occupied by another process or thread. - /// Cuda operations that initialize the context are needed to check - /// the permission. cudaFree(0) is one of those with no side effect, - /// except the context initialization. - static bool CheckDevice(const int device_id); - }; - } // namespace singa #endif // SINGA_CORE_DEVICE_H_ diff --git a/python/singa/device.py b/python/singa/device.py index 749db4c124..9e38b14a52 100644 --- a/python/singa/device.py +++ b/python/singa/device.py @@ -120,20 +120,21 @@ def create_cuda_gpu_on(device_id): return devices[0] -def get_num_ocl_platforms(): - return singa.Platform.GetNumOCLPlatforms() +def get_num_opencl_platforms(): + return singa.Platform.GetNumOpenclPlatforms() -def get_num_ocl_devices(): - return singa.Platform.GetNumOCLDevices() + +def get_num_opencl_devices(): + return singa.Platform.GetNumOpenclDevices() def create_opencl_device(): '''Create the default OpenCL device. - + Returns: a swig converted OpenCL device. ''' - return singa.Platform.GetDefaultOCLDevice() + return singa.Platform.GetDefaultOpenclDevice() default_device = singa.Platform.GetDefaultDevice() diff --git a/python/singa/net.py b/python/singa/net.py index 7824c1cc83..305ecaefa4 100644 --- a/python/singa/net.py +++ b/python/singa/net.py @@ -19,7 +19,7 @@ functions for net info, e.g., parameters. """ -import timeit +from timeit import default_timer as timer from .proto.model_pb2 import kTrain, kEval import tensor import layer @@ -29,6 +29,10 @@ '''For display training information, e.g L1 value of layer data''' verbose = False benchmark = True +forward_time = {} # forward time for each layer +backward_time = {} # backward time for each layer +bp_time = [0, 0, 0] # forward + backward, forward, backward + class FeedForwardNet(object): @@ -129,27 +133,24 @@ def train(self, x, y): Returns: gradients of parameters and the loss and metric values. ''' - out = self.forward(kTrain, x) - l = self.loss.forward(kTrain, out, y) - if self.metric is not None: - m = self.metric.evaluate(out, y) - return self.backward(), (l.l1(), m) - - def train_benchmark(self, x, y): - t0 = timeit.default_timer() - out = self.forward(kTrain, x) - t0 = timeit.default_timer() - t0 - - l = self.loss.forward(kTrain, out, y) - if self.metric is not None: - m = self.metric.evaluate(out, y) - - t1 = timeit.default_timer() - grads = self.backward() - t1 = timeit.default_timer() - t1 - - print("Forward: {0:.4f}\tBackward: {0:.4f}".format(t0, t1)) - return grads, (l.l1(), m) + if benchmark: + global bp_time + t1 = timer() + out = self.forward(kTrain, x) + l = self.loss.forward(kTrain, out, y) + t2 = timer() + ret = self.backward() + t3 = timer() + bp_time[0] += t3 - t1 + bp_time[1] += t2 - t1 + bp_time[2] += t3 - t2 + return ret, (l.l1(), None) + else: + out = self.forward(kTrain, x) + l = self.loss.forward(kTrain, out, y) + if self.metric is not None: + m = self.metric.evaluate(out, y) + return self.backward(), (l.l1(), m) def evaluate(self, x, y): '''Evaluate the loss and metric of the given data. @@ -266,7 +267,15 @@ def forward(self, flag, x, output=[]): output_of_layer.pop(src.name) if len(inputs) == 1: inputs = inputs[0] - out = cur.forward(flag, inputs) + + if benchmark: + global forward_time + start_tick = timer() + out = cur.forward(flag, inputs) + forward_time[cur.name] += timer() - start_tick + else: + out = cur.forward(flag, inputs) + if verbose: disp_src = '+'.join([src.name for src in srcs]) disp_src += '-->' + cur.name @@ -316,7 +325,13 @@ def backward(self): # del output_of_layer[dst.name] if len(grads) == 1: grads = grads[0] - outs, _pgrads = cur.backward(kTrain, grads) + if benchmark: + global backward_time + start_tick = timer() + outs, _pgrads = cur.backward(kTrain, grads) + backward_time[cur.name] += timer() - start_tick + else: + outs, _pgrads = cur.backward(kTrain, grads) pgrads.append(_pgrads) if verbose: disp_src = '+'.join( @@ -338,6 +353,35 @@ def backward(self): ret.extend(pgrad) return ret + def start_benchmark(self): + '''Reset the internal arrays to start benchmark, must be called before + calling the train() function. + ''' + global benchmark, bp_time + benchmark = True + bp_time = [0, 0, 0] + for ly in self.layers: + forward_time[ly.name] = 0 + backward_time[ly.name] = 0 + + def stop_benchmark(self, num): + '''Stop the benchmark and return the time information. + + Args: + num(int), number of total iterations + + Returns: + time for the following procedures within one iteration + foward-backward, forward, backward, [forward of each layer], + [backward of each layer] + ''' + fp = [] + bp = [] + for lyr in self.ordered_layers: + fp.append((lyr.name, forward_time[lyr.name] / num)) + bp.append((lyr.name, backward_time[lyr.name] / num)) + return bp_time[0] / num, bp_time[1] / num, bp_time[2] / num, fp, bp + def save(self, f, buffer_size=10, use_pickle=False): '''Save model parameters using io/snapshot. diff --git a/src/api/core_device.i b/src/api/core_device.i index 04d028afe3..a5b7de6e90 100644 --- a/src/api/core_device.i +++ b/src/api/core_device.i @@ -44,14 +44,14 @@ namespace std{ namespace singa{ class Device { -public: + public: virtual void SetRandSeed(unsigned seed) = 0; std::shared_ptr host(); int id() const; }; class Platform { -public: + public: #if USE_CUDA static int GetNumGPUs(); static const std::vector GetGPUIDs(); @@ -66,9 +66,9 @@ public: #if USE_OPENCL - const int GetNumOCLPlatforms(); - const int GetNumOCLDevices(); - static const std::shared_ptr GetDefaultOCLDevice(); + const int GetNumOpenclPlatforms(); + const int GetNumOpenclDevices(); + static const std::shared_ptr GetDefaultOpenclDevice(); // static const std::vector> // CreateOpenclDevices(const size_t num_devices); // static const std::vector> diff --git a/src/core/device/platform.cc b/src/core/device/platform.cc index 6de0b396ba..8ae15f8604 100644 --- a/src/core/device/platform.cc +++ b/src/core/device/platform.cc @@ -142,12 +142,12 @@ Platform::CreateCudaGPUsOn(const vector &devices, size_t init_size) { #ifdef USE_OPENCL -const int Platform::GetNumOCLPlatforms() { +const int Platform::GetNumOpenclPlatforms() { auto all_platforms = viennacl::ocl::get_platforms(); return (int)all_platforms.size(); } -const int Platform::GetNumOCLDevices() { +const int Platform::GetNumOpenclDevices() { auto all_platforms = viennacl::ocl::get_platforms(); unsigned int total_num_devices = 0; for (auto plat : all_platforms) { @@ -157,12 +157,12 @@ const int Platform::GetNumOCLDevices() { return (int)total_num_devices; } -const std::shared_ptr Platform::GetDefaultOCLDevice() { +const std::shared_ptr Platform::GetDefaultOpenclDevice() { return std::make_shared(); } /* static const std::vector> -Platform::CreateOCLDevices(const size_t num_devices) { +Platform::CreateOpenclDevices(const size_t num_devices) { auto all_platforms = viennacl::ocl::get_platforms(); for (auto plat : all_platforms) { auto all_devices = plat.devices(CL_DEVICE_TYPE_ALL); @@ -172,8 +172,8 @@ Platform::CreateOCLDevices(const size_t num_devices) { } static const std::vector> -Platform::CreateOCLDevices(const std::vector &id) { - +Platform::CreateOpenclDevices(const std::vector &id) { + } */ #endif // USE_OPENCL From ffb8032ace109b22d74d7a4f09c944285440815f Mon Sep 17 00:00:00 2001 From: wangwei Date: Sat, 19 Nov 2016 09:53:03 +0000 Subject: [PATCH 6/6] tested benchmarks; segment fault from vgg+cuda; --- examples/benchmark/README.md | 14 ++++++-- examples/benchmark/run.py | 53 +++++++++++++--------------- examples/benchmark/vgg.py | 8 ++--- python/singa/net.py | 15 ++++---- src/model/layer/cudnn_convolution.cc | 14 ++++++++ 5 files changed, 62 insertions(+), 42 deletions(-) diff --git a/examples/benchmark/README.md b/examples/benchmark/README.md index 2f5855d7ff..4092ac6635 100644 --- a/examples/benchmark/README.md +++ b/examples/benchmark/README.md @@ -1,5 +1,15 @@ #Benchmark scripts -These scripts will create a neural net modelled after the ones specified in [convnet-benchmarks](https://github.com/soumith/convnet-benchmarks/tree/master/caffe/imagenet_winners). +These scripts will test the efficiency of SINGA by training benchmark models pecified in +[convnet-benchmarks](https://github.com/soumith/convnet-benchmarks/tree/master/caffe/imagenet_winners) +over different devices (e.g., CPU and GPU). -To run them, create a python pip virtualenv or anaconda virtual environment as guided by [this article](http://singa.apache.org/en/docs/installation.html#pip-and-anaconda-for-pysinga). Then, execute the python scripts in this folder. +To run them, create a python pip virtualenv or anaconda virtual environment as +guided by [this article](http://singa.apache.org/en/docs/installation.html#pip-and-anaconda-for-pysinga). +Then, execute the `run.py` as + + $ python run.py + +Different models and devices could be tested, please refer to the command line help message, + + $ python run.py -h diff --git a/examples/benchmark/run.py b/examples/benchmark/run.py index 475ba556c4..d35ff5dead 100644 --- a/examples/benchmark/run.py +++ b/examples/benchmark/run.py @@ -14,50 +14,44 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -''' This model is created following the structure from -https://github.com/soumith/convnet-benchmarks/blob/master/caffe/imagenet_winners/alexnet.prototxt -''' -import sys -import numpy as np +from timeit import timeit as timer +import argparse from singa import device from singa import tensor from singa import optimizer -iterations = 10 -batch_size = 128 -input_shape = (3, 224, 224) -# Time forward, backward, parameter update, per layer (1x forward, 1x backward) -def train(net, dev): +def train(net, dev, num_iter=10, batch_size=128, input_shape=(3, 244, 244)): + '''Train the net for multiple iterations to measure the efficiency. + + Including timer per iteration, forward, backward, parameter update and + timer for each layer.''' + tx = tensor.Tensor((batch_size,) + input_shape, dev) ty = tensor.Tensor((batch_size,), dev) tx.gaussian(1.0, 0.5) ty.set_value(0.0) opt = optimizer.SGD(momentum=0.9) - idx = np.arange(tx.shape[0], dtype = np.int32) - loss = 0.0 - acc = 0.0 - train_time = 0.0 - update_time = 0.0 net.start_benchmark() update = 0 - for b in range(iterations): - grads, (l, a) = net.train_benchmark(tx, ty) + for b in range(num_iter): + print b + grads, (l, a) = net.train(tx, ty) t1 = timer() for (s, p, g) in zip(net.param_names(), net.param_values(), grads): opt.apply_with_lr(0, 0.01, g, p, str(s), b) update += timer() - t1 - t, fp, bp, fps, bps = net.stop_benchmark(iterations) + iter_time, fps, bps = net.stop_benchmark(num_iter) - print "Total iterations = %d" % iterations - print "Average training time per iteration = %.4f" % t - print "Average forward time per iteration = %.4f" % fp - print "Average backward time per iteration = %.4f" % bp - print "Average udpate time per iteration = %.4f" % (update / iterations) + print "Total iterations = %d" % num_iter + print "Average training time per iteration = %.4f" % iter_time[0] + print "Average forward time per iteration = %.4f" % iter_time[1] + print "Average backward time per iteration = %.4f" % iter_time[2] + print "Average udpate time per iteration = %.4f" % (update / num_iter) for (k, v) in fps: print "Forward time for %10s = %.4f" % (k, v) for (k, v) in bps: @@ -65,8 +59,8 @@ def train(net, dev): if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Benchmark SINGA by running' - 'AlexNet/VGG/Overfeat with CPP/CUDA/Opencl') + parser = argparse.ArgumentParser(description='Benchmark SINGA by training' + 'AlexNet/VGG/Overfeat with on CPU/GPU') parser.add_argument('net', choices=['vgg', 'alexnet', 'overfeat'], default='alexnet') parser.add_argument('device', choices=['cpp', 'cuda', 'opencl'], @@ -80,18 +74,19 @@ def train(net, dev): assert args.net == 'overfeat', 'Wrong net type:' + args.net import overfeat as model - use_cpu = False, + use_cpu = False use_opencl = False - if args.device == 'cpu': + if args.device == 'cpp': use_cpu = True dev = device.get_default_device() elif args.device == 'cuda': - dev = device.create_cuda_gpu() + dev = device.create_cuda_gpu_on(2) else: assert args.device == 'opencl', 'Wrong lang: ' + args.device use_opencl = True dev = device.create_opencl_device() + input_shape = (3, 244, 244,) net = model.create_net(input_shape, use_cpu, use_opencl) net.to_device(dev) - train(net, dev) + train(net, dev, input_shape=input_shape) diff --git a/examples/benchmark/vgg.py b/examples/benchmark/vgg.py index 0dcdd93493..0b7d75bc0d 100644 --- a/examples/benchmark/vgg.py +++ b/examples/benchmark/vgg.py @@ -23,6 +23,7 @@ from singa import metric from singa import net as ffnet +ffnet.verbose=True def create_net(input_shape, use_cpu=False, use_ocl=False): if use_cpu: @@ -71,11 +72,10 @@ def create_net(input_shape, use_cpu=False, use_ocl=False): net.add(layer.Dense("fc8", 1000)) for (val, spec) in zip(net.param_values(), net.param_specs()): - filler = spec.filler - if filler.type == 'gaussian': - val.gaussian(filler.mean, filler.std) + if len(val.shape) > 1: + val.gaussian(0, 0.01) else: val.set_value(0) - print spec.name, filler.type, val.l1() + print spec.name, spec.filler.type, val.l1() return net diff --git a/python/singa/net.py b/python/singa/net.py index 305ecaefa4..ec479a98ef 100644 --- a/python/singa/net.py +++ b/python/singa/net.py @@ -28,10 +28,10 @@ '''For display training information, e.g L1 value of layer data''' verbose = False -benchmark = True +benchmark = False forward_time = {} # forward time for each layer backward_time = {} # backward time for each layer -bp_time = [0, 0, 0] # forward + backward, forward, backward +iter_time = [0, 0, 0] # time for one iteration, forward, and backward @@ -141,9 +141,9 @@ def train(self, x, y): t2 = timer() ret = self.backward() t3 = timer() - bp_time[0] += t3 - t1 - bp_time[1] += t2 - t1 - bp_time[2] += t3 - t2 + iter_time[0] += t3 - t1 + iter_time[1] += t2 - t1 + iter_time[2] += t3 - t2 return ret, (l.l1(), None) else: out = self.forward(kTrain, x) @@ -363,6 +363,7 @@ def start_benchmark(self): for ly in self.layers: forward_time[ly.name] = 0 backward_time[ly.name] = 0 + return iter_time, forward_time, backward_time def stop_benchmark(self, num): '''Stop the benchmark and return the time information. @@ -372,7 +373,7 @@ def stop_benchmark(self, num): Returns: time for the following procedures within one iteration - foward-backward, forward, backward, [forward of each layer], + [foward-backward, forward, backward], [forward of each layer], [backward of each layer] ''' fp = [] @@ -380,7 +381,7 @@ def stop_benchmark(self, num): for lyr in self.ordered_layers: fp.append((lyr.name, forward_time[lyr.name] / num)) bp.append((lyr.name, backward_time[lyr.name] / num)) - return bp_time[0] / num, bp_time[1] / num, bp_time[2] / num, fp, bp + return [t / num for t in iter_time], fp, bp def save(self, f, buffer_size=10, use_pickle=False): '''Save model parameters using io/snapshot. diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc index 196d1375b1..8245bf01ac 100644 --- a/src/model/layer/cudnn_convolution.cc +++ b/src/model/layer/cudnn_convolution.cc @@ -173,10 +173,23 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) { Shape shape{batchsize, num_filters_, conv_height_, conv_width_}; Tensor output(shape, dev, dtype); + LOG(ERROR) << "input: " << input.shape(0) << ", " << input.shape(1) << ", " + << input.shape(2) << ", " << input.shape(3); + LOG(ERROR) << "weight: " << weight_.shape(0) << ", " << weight_.shape(1); + LOG(ERROR) << "output: " << output.shape(0) << ", " << output.shape(1) << ", " + << output.shape(2) << ", " << output.shape(3); + output.device()->Exec([input, output, this](Context *ctx) { Block *inblock = input.block(), *outblock = output.block(), *wblock = this->weight_.block(); float alpha = 1.f, beta = 0.f; + /* + LOG(ERROR) << "before conv"; + CHECK(inblock->data() != nullptr); + CHECK(wblock->data() != nullptr); + CHECK(outblock->mutable_data() != nullptr); + CHECK(workspace_.block()->mutable_data() != nullptr); + */ cudnnConvolutionForward(ctx->cudnn_handle, &alpha, this->x_desc_, inblock->data(), this->filter_desc_, wblock->data(), this->conv_desc_, this->fp_alg_, @@ -185,6 +198,7 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) { this->y_desc_, outblock->mutable_data()); }, {input.block(), weight_.block()}, {output.block()}, workspace_.block()); + // LOG(ERROR) << "before bias"; if (bias_term_) { output.device()->Exec([output, this](Context *ctx) { float beta = 1.f, alpha = 1.0f;