diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index d5bfbd926d..708628b094 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -91,12 +91,11 @@ ENDIF()
 IF(USE_OPENCL)
     FIND_PACKAGE(OpenCL REQUIRED)
-    IF(NOT OpenCL_FOUND)
+    IF(NOT OPENCL_FOUND)
         MESSAGE(SEND_ERROR "OpenCL was requested, but not found.")
     ELSE()
-        #MESSAGE(STATUS "Found OpenCL headers at ${OpenCL_INCLUDE_DIRS}")
-        INCLUDE_DIRECTORIES(SYSTEM ${OpenCL_INCLUDE_DIR})
-        LIST(APPEND SINGA_LINKER_LIBS ${OpenCL_LIBRARIES})
+        INCLUDE_DIRECTORIES(SYSTEM ${OPENCL_INCLUDE_DIR})
+        LIST(APPEND SINGA_LINKER_LIBS ${OPENCL_LIBRARIES})
         FIND_PACKAGE(ViennaCL REQUIRED)
         IF(NOT ViennaCL_FOUND)
             MESSAGE(SEND_ERROR "ViennaCL is required if OpenCL is enabled.")

diff --git a/cmake/Thirdparty/FindOpenCL.cmake b/cmake/Thirdparty/FindOpenCL.cmake
new file mode 100644
index 0000000000..c358d8a1a4
--- /dev/null
+++ b/cmake/Thirdparty/FindOpenCL.cmake
@@ -0,0 +1,88 @@
+# - Find the OpenCL headers and library
+#
+# Defines the following if found:
+#   OPENCL_FOUND        : TRUE if found, FALSE otherwise
+#   OPENCL_INCLUDE_DIRS : Include directories for OpenCL
+#   OPENCL_LIBRARIES    : The libraries to link against
+#
+# The user can set the OPENCLROOT environment variable to help finding OpenCL
+# if it is installed in a non-standard place.
+
+set(ENV_ATISTREAMSDKROOT "$ENV{ATISTREAMSDKROOT}")
+if(ENV_ATISTREAMSDKROOT)
+  set(ENV_OPENCLROOT "$ENV{ATISTREAMSDKROOT}")
+endif(ENV_ATISTREAMSDKROOT)
+
+set(ENV_AMDAPPSDKROOT "$ENV{AMDAPPSDKROOT}")
+if(ENV_AMDAPPSDKROOT)
+  set(ENV_OPENCLROOT "$ENV{AMDAPPSDKROOT}")
+endif(ENV_AMDAPPSDKROOT)
+
+set(ENV_INTELOCLSDKROOT "$ENV{INTELOCLSDKROOT}")
+if(ENV_INTELOCLSDKROOT)
+  set(ENV_OPENCLROOT "$ENV{INTELOCLSDKROOT}")
+endif(ENV_INTELOCLSDKROOT)
+
+set(ENV_OPENCLROOT2 "$ENV{OPENCLROOT}")
+if(ENV_OPENCLROOT2)
+  set(ENV_OPENCLROOT "$ENV{OPENCLROOT}")
+endif(ENV_OPENCLROOT2)
+
+if(ENV_OPENCLROOT)
+  find_path(
+    OPENCL_INCLUDE_DIR
+    NAMES CL/cl.h OpenCL/cl.h
+    PATHS "${ENV_OPENCLROOT}/include"
+    #NO_DEFAULT_PATH #uncomment this if you wish to suppress the use of default paths for OpenCL
+    )
+
+  if(("${CMAKE_SYSTEM_NAME}" MATCHES "Linux") OR ("${CMAKE_SYSTEM_NAME}" MATCHES "Windows"))
+    if(CMAKE_SIZEOF_VOID_P EQUAL 4)
+      set(OPENCL_LIB_SEARCH_PATH
+        "${OPENCL_LIB_SEARCH_PATH}"
+        "${ENV_OPENCLROOT}/lib/x86")
+    else(CMAKE_SIZEOF_VOID_P EQUAL 4)
+      set(OPENCL_LIB_SEARCH_PATH
+        "${OPENCL_LIB_SEARCH_PATH}"
+        "${ENV_OPENCLROOT}/lib/x86_64")
+    endif(CMAKE_SIZEOF_VOID_P EQUAL 4)
+  endif(("${CMAKE_SYSTEM_NAME}" MATCHES "Linux") OR ("${CMAKE_SYSTEM_NAME}" MATCHES "Windows"))
+
+  find_library(
+    OPENCL_LIBRARY
+    NAMES OpenCL
+    PATHS "${OPENCL_LIB_SEARCH_PATH}"
+    #NO_DEFAULT_PATH #uncomment this if you wish to suppress the use of default paths for OpenCL
+    )
+else(ENV_OPENCLROOT)
+  find_path(
+    OPENCL_INCLUDE_DIR
+    NAMES CL/cl.h OpenCL/cl.h
+    PATHS "${PROJECT_SOURCE_DIR}" #use the CL/ include folder provided with ViennaCL
+    )
+
+  find_library(
+    OPENCL_LIBRARY
+    NAMES OpenCL
+    )
+endif(ENV_OPENCLROOT)
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(
+  OPENCL
+  DEFAULT_MSG
+  OPENCL_LIBRARY OPENCL_INCLUDE_DIR
+  )
+
+if(OPENCL_FOUND)
+  set(OPENCL_INCLUDE_DIRS "${OPENCL_INCLUDE_DIR}")
+  set(OPENCL_LIBRARIES "${OPENCL_LIBRARY}")
+else(OPENCL_FOUND)
+  set(OPENCL_INCLUDE_DIRS)
+  set(OPENCL_LIBRARIES)
+endif(OPENCL_FOUND)
+
+mark_as_advanced(
+  OPENCL_INCLUDE_DIR
+  OPENCL_LIBRARY
+  )

diff --git a/cmake/Thirdparty/FindViennaCL.cmake b/cmake/Thirdparty/FindViennaCL.cmake
index 263c80fdca..f18c4be1d8 100644
--- a/cmake/Thirdparty/FindViennaCL.cmake
+++ b/cmake/Thirdparty/FindViennaCL.cmake
@@ -1,9 +1,14 @@
 # This file is retrieved from caffe/cmake/Modules/FindViennaCL.cmake.
+# from the opencl branch on BVLC/Caffe.
 
 SET(ViennaCL_WITH_OPENCL TRUE)
 
 SET(VIENNACL_INCLUDE_SEARCH_PATHS
+    viennacl
+    viennacl-dev
     ..
+    ../viennacl
+    ../viennacl-dev
     /usr/include
     /usr/local/include
     /opt/ViennaCL/include
@@ -15,7 +20,7 @@ FIND_PATH(ViennaCL_INCLUDE_DIR NAMES viennacl/forwards.h PATHS ${VIENNACL_INCLUD
 
 SET(ViennaCL_FOUND ON)
 
-# Check include files
+# Check include files
 IF(NOT ViennaCL_INCLUDE_DIR)
     SET(ViennaCL_FOUND OFF)
     MESSAGE(STATUS "Could not find ViennaCL include. Turning ViennaCL_FOUND off")
@@ -33,6 +38,10 @@ ENDIF (ViennaCL_FOUND)
 
 IF(ViennaCL_WITH_OPENCL)
   find_package(OpenCL REQUIRED)
+  IF(NOT OPENCL_INCLUDE_DIRS)
+    MESSAGE(FATAL_ERROR "Could not find OpenCL include.")
+  ENDIF()
+  MESSAGE(STATUS "Found OpenCL include: ${OPENCL_INCLUDE_DIRS}")
 ENDIF(ViennaCL_WITH_OPENCL)
 
 set(ViennaCL_INCLUDE_DIRS ${ViennaCL_INCLUDE_DIR} ${OPENCL_INCLUDE_DIRS})

diff --git a/examples/benchmark/README.md b/examples/benchmark/README.md
new file mode 100644
index 0000000000..4092ac6635
--- /dev/null
+++ b/examples/benchmark/README.md
@@ -0,0 +1,15 @@
+# Benchmark scripts
+
+These scripts test the efficiency of SINGA by training the benchmark models specified in
+[convnet-benchmarks](https://github.com/soumith/convnet-benchmarks/tree/master/caffe/imagenet_winners)
+on different devices (e.g., CPU and GPU).
+
+To run them, create a python pip virtualenv or anaconda virtual environment as
+guided by [this article](http://singa.apache.org/en/docs/installation.html#pip-and-anaconda-for-pysinga).
+Then, execute `run.py` as
+
+    $ python run.py
+
+Different models and devices can be tested; please refer to the command-line help message,
+
+    $ python run.py -h
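For example, `run.py` (shown further below) takes two positional arguments, the model name and the device type, so benchmarking AlexNet on a CUDA GPU would look like:

    $ python run.py alexnet cuda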
diff --git a/examples/benchmark/alexnet.py b/examples/benchmark/alexnet.py
new file mode 100644
index 0000000000..48c812ab53
--- /dev/null
+++ b/examples/benchmark/alexnet.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+'''This model is created following the structure from
+https://github.com/soumith/convnet-benchmarks/blob/master/caffe/imagenet_winners/alexnet.prototxt
+'''
+
+from singa import layer
+from singa import loss
+from singa import metric
+from singa import net as ffnet
+
+
+def create_net(input_shape, use_cpu=False, use_ocl=False):
+    if use_cpu:
+        layer.engine = 'singacpp'
+    if use_ocl:
+        layer.engine = 'singacl'
+
+    net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())
+
+    # Conv 1
+    net.add(layer.Conv2D("conv1", 64, 11, 4, pad=2,
+                         input_sample_shape=input_shape))
+    net.add(layer.Activation("conv1/relu"))
+    net.add(layer.MaxPooling2D("pool1/3x3_s2", 3, 2))
+
+    # Conv 2
+    net.add(layer.Conv2D("conv2/5x5_s1", 192, 5, 1, pad=2))
+    net.add(layer.Activation("conv2/relu"))
+    net.add(layer.MaxPooling2D("pool2/3x3_s2", 3, 2))
+
+    # Conv 3
+    net.add(layer.Conv2D("conv3/3x3_s1", 384, 3, 1, pad=1))
+    net.add(layer.Activation("conv3/relu"))
+
+    # Conv 4
+    net.add(layer.Conv2D("conv4/3x3_s1", 256, 3, 1, pad=1))
+    net.add(layer.Activation("conv4/relu"))
+
+    # Conv 5
+    net.add(layer.Conv2D("conv5/3x3_s1", 256, 3, 1, pad=1))
+    net.add(layer.Activation("conv5/relu"))
+    net.add(layer.MaxPooling2D("pool5/3x3_s2", 3, 2))
+
+    # L2 Norm -> Inner product
+    net.add(layer.Flatten("flat"))
+    net.add(layer.Dense("fc6", 4096))
+    net.add(layer.Activation("fc6/relu6"))
+
+    net.add(layer.Dense("fc7", 4096))
+    net.add(layer.Activation("fc7/relu7"))
+
+    net.add(layer.Dense("fc8", 1000))
+
+    for (val, spec) in zip(net.param_values(), net.param_specs()):
+        filler = spec.filler
+        if filler.type == 'gaussian':
+            val.gaussian(filler.mean, filler.std)
+        else:
+            val.set_value(0)
+        print spec.name, filler.type, val.l1()
+
+    return net

diff --git a/examples/benchmark/overfeat.py b/examples/benchmark/overfeat.py
new file mode 100644
index 0000000000..ca7a99d8fc
--- /dev/null
+++ b/examples/benchmark/overfeat.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+'''This model is created following the structure from
+https://github.com/soumith/convnet-benchmarks/blob/master/caffe/imagenet_winners/overfeat.prototxt
+'''
+
+from singa import layer
+from singa import loss
+from singa import metric
+from singa import net as ffnet
+
+
+def create_net(input_shape, use_cpu=False, use_ocl=False):
+    if use_cpu:
+        layer.engine = 'singacpp'
+    if use_ocl:
+        layer.engine = 'singacl'
+
+    net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())
+
+    # Conv 1
+    net.add(layer.Conv2D("conv1", 96, 11, 4, input_sample_shape=input_shape))
+    net.add(layer.Activation("conv1/relu"))
+    net.add(layer.MaxPooling2D("pool1/2x2_s2", 2, 2, border_mode='valid'))
+
+    # Conv 2
+    net.add(layer.Conv2D("conv2/5x5_s1", 256, 5, 1))
+    net.add(layer.Activation("conv2/relu"))
+    net.add(layer.MaxPooling2D("pool2/2x2_s2", 2, 2, border_mode='valid'))
+
+    # Conv 3
+    net.add(layer.Conv2D("conv3/3x3_s1", 512, 3, 1, pad=1))
+    net.add(layer.Activation("conv3/relu"))
+
+    # Conv 4
+    net.add(layer.Conv2D("conv4/3x3_s1", 1024, 3, 1, pad=1))
+    net.add(layer.Activation("conv4/relu"))
+
+    # Conv 5
+    net.add(layer.Conv2D("conv5/3x3_s1", 1024, 3, 1, pad=1))
+    net.add(layer.Activation("conv5/relu"))
+    net.add(layer.MaxPooling2D("pool5/2x2_s2", 2, 2, border_mode='valid'))
+
+    # L2 Norm -> Inner product
+    net.add(layer.Flatten("flat"))
+    net.add(layer.Dense("fc6", 3072))
+    net.add(layer.Activation("fc6/relu6"))
+
+    net.add(layer.Dense("fc7", 4096))
+    net.add(layer.Activation("fc7/relu7"))
+
+    net.add(layer.Dense("fc8", 1000))
+
+    for (val, spec) in zip(net.param_values(), net.param_specs()):
+        filler = spec.filler
+        if filler.type == 'gaussian':
+            val.gaussian(filler.mean, filler.std)
+        else:
+            val.set_value(0)
+        print spec.name, filler.type, val.l1()
+
+    return net

diff --git a/examples/benchmark/run.py b/examples/benchmark/run.py
new file mode 100644
index 0000000000..d35ff5dead
--- /dev/null
+++ b/examples/benchmark/run.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from timeit import default_timer as timer
+import argparse
+
+from singa import device
+from singa import tensor
+from singa import optimizer
+
+
+def train(net, dev, num_iter=10, batch_size=128, input_shape=(3, 244, 244)):
+    '''Train the net for multiple iterations to measure the efficiency.
+
+    It reports the average time per iteration, the forward time, the backward
+    time, the parameter update time, and the forward and backward time of
+    each layer.
+    '''
+    tx = tensor.Tensor((batch_size,) + input_shape, dev)
+    ty = tensor.Tensor((batch_size,), dev)
+    tx.gaussian(1.0, 0.5)
+    ty.set_value(0.0)
+
+    opt = optimizer.SGD(momentum=0.9)
+
+    net.start_benchmark()
+    update = 0
+    for b in range(num_iter):
+        print b
+        grads, (l, a) = net.train(tx, ty)
+        t1 = timer()
+        for (s, p, g) in zip(net.param_names(), net.param_values(), grads):
+            opt.apply_with_lr(0, 0.01, g, p, str(s), b)
+        update += timer() - t1
+    iter_time, fps, bps = net.stop_benchmark(num_iter)
+
+    print "Total iterations = %d" % num_iter
+    print "Average training time per iteration = %.4f" % iter_time[0]
+    print "Average forward time per iteration = %.4f" % iter_time[1]
+    print "Average backward time per iteration = %.4f" % iter_time[2]
+    print "Average update time per iteration = %.4f" % (update / num_iter)
+    for (k, v) in fps:
+        print "Forward time for %10s = %.4f" % (k, v)
+    for (k, v) in bps:
+        print "Backward time for %10s = %.4f" % (k, v)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Benchmark SINGA by training '
+                                     'AlexNet/VGG/Overfeat on CPU/GPU')
+    parser.add_argument('net', choices=['vgg', 'alexnet', 'overfeat'],
+                        nargs='?', default='alexnet')
+    parser.add_argument('device', choices=['cpp', 'cuda', 'opencl'],
+                        nargs='?', default='cuda')
+    args = parser.parse_args()
+    if args.net == 'vgg':
+        import vgg as model
+    elif args.net == 'alexnet':
+        import alexnet as model
+    else:
+        assert args.net == 'overfeat', 'Wrong net type: ' + args.net
+        import overfeat as model
+
+    use_cpu = False
+    use_opencl = False
+
+    if args.device == 'cpp':
+        use_cpu = True
+        dev = device.get_default_device()
+    elif args.device == 'cuda':
+        dev = device.create_cuda_gpu_on(2)
+    else:
+        assert args.device == 'opencl', 'Wrong lang: ' + args.device
+        use_opencl = True
+        dev = device.create_opencl_device()
+
+    input_shape = (3, 244, 244,)
+    net = model.create_net(input_shape, use_cpu, use_opencl)
+    net.to_device(dev)
+    train(net, dev, input_shape=input_shape)
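As a quick illustration of how the model modules plug into this script, a minimal session (a sketch only; it assumes a CUDA build and reuses `create_cuda_gpu_on` from `device.py`) would be:

    # Sketch: build one benchmark net and move it to a device.
    from singa import device
    import alexnet

    dev = device.create_cuda_gpu_on(0)        # pick GPU 0
    net = alexnet.create_net((3, 244, 244))   # default engine: cudnn
    net.to_device(dev)
    # train(net, dev, input_shape=(3, 244, 244)) then drives the benchmark loop.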
diff --git a/examples/benchmark/vgg.py b/examples/benchmark/vgg.py
new file mode 100644
index 0000000000..0b7d75bc0d
--- /dev/null
+++ b/examples/benchmark/vgg.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+'''This model is created following the structure from
+https://github.com/soumith/convnet-benchmarks/blob/master/caffe/imagenet_winners/vgg_a.prototxt
+'''
+
+from singa import layer
+from singa import loss
+from singa import metric
+from singa import net as ffnet
+
+ffnet.verbose = True
+
+
+def create_net(input_shape, use_cpu=False, use_ocl=False):
+    if use_cpu:
+        layer.engine = 'singacpp'
+    if use_ocl:
+        layer.engine = 'singacl'
+
+    net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())
+
+    net.add(layer.Conv2D("conv1/3x3_s1", 64, 3, 1, pad=1,
+                         input_sample_shape=input_shape))
+    net.add(layer.Activation("conv1/relu"))
+    net.add(layer.MaxPooling2D("pool1/2x2_s2", 2, 2, border_mode='valid'))
+
+    net.add(layer.Conv2D("conv2/3x3_s1", 128, 3, 1, pad=1))
+    net.add(layer.Activation("conv2/relu"))
+    net.add(layer.MaxPooling2D("pool2/2x2_s2", 2, 2, border_mode='valid'))
+
+    net.add(layer.Conv2D("conv3/3x3_s1", 256, 3, 1, pad=1))
+    net.add(layer.Activation("conv3/relu"))
+    # No pooling layer here.
+
+    net.add(layer.Conv2D("conv4/3x3_s1", 256, 3, 1, pad=1))
+    net.add(layer.Activation("conv4/relu"))
+    net.add(layer.MaxPooling2D("pool3/2x2_s2", 2, 2, border_mode='valid'))
+
+    net.add(layer.Conv2D("conv5/3x3_s1", 512, 3, 1, pad=1))
+    net.add(layer.Activation("conv5/relu"))
+    # No pooling layer here.
+
+    net.add(layer.Conv2D("conv6/3x3_s1", 512, 3, 1, pad=1))
+    net.add(layer.Activation("conv6/relu"))
+    net.add(layer.MaxPooling2D("pool4/2x2_s2", 2, 2, border_mode='valid'))
+
+    net.add(layer.Conv2D("conv7/3x3_s1", 512, 3, 1, pad=1))
+    net.add(layer.Activation("conv7/relu"))
+    # No pooling layer here.
+
+    net.add(layer.Conv2D("conv8/3x3_s1", 512, 3, 1, pad=1))
+    net.add(layer.Activation("conv8/relu"))
+    net.add(layer.MaxPooling2D("pool5/2x2_s2", 2, 2, border_mode='valid'))
+
+    net.add(layer.Flatten('flat'))
+    net.add(layer.Dense("fc6", 4096))
+    net.add(layer.Dense("fc7", 4096))
+    net.add(layer.Dense("fc8", 1000))
+
+    for (val, spec) in zip(net.param_values(), net.param_specs()):
+        if len(val.shape) > 1:
+            val.gaussian(0, 0.01)
+        else:
+            val.set_value(0)
+        print spec.name, spec.filler.type, val.l1()
+
+    return net

diff --git a/include/singa/core/device.h b/include/singa/core/device.h
index 0fecc6d8f2..6bb8193d29 100644
--- a/include/singa/core/device.h
+++ b/include/singa/core/device.h
@@ -50,7 +50,7 @@ namespace singa {
 /// There are three types of devices distinguished by their programming
 /// languages, namely cpp, cuda and opencl.
 class Device {
- public:
+ public:
   // Device() = default;
   virtual ~Device() {}
   /// Constructor with device ID, num of executors (e.g., cuda streams),
@@ -102,10 +102,10 @@
   int id() const { return id_; }
 
- private:
+private:
   Device() {};
 
- protected:
+protected:
   /// Execute one operation on one executor.
   virtual void DoExec(function<void(Context*)>&& fn, int executor) = 0;
@@ -203,7 +203,7 @@ class CudaGPU : public Device {
 #ifdef USE_OPENCL
 // Implement Device using OpenCL libs.
 class OpenclDevice : public singa::Device {
-public:
+ public:
 
   // TODO: Constructor arguments to consider:
   // Path to kernel sources?
@@ -218,7 +218,7 @@ class OpenclDevice : public singa::Device {
                       CopyDirection direction, int dst_offset = 0,
                       int src_offset = 0) override;
 
-protected:
+ protected:
   /// The OpenCL device that this object represents.
   /// Each OpenclDevice contains exactly one cl::Device for the lifetime of the
   /// object.
@@ -248,7 +248,7 @@ class OpenclDevice : public singa::Device {
   /// This has the effect of freeing up device memory.
   void Free(void* ptr) override;
 
-private:
+ private:
   static const std::string cl_src_path;
 };
@@ -260,7 +260,7 @@
 /// return something that indicates their absence (for example, 0 devices);
 /// however they should always be available regardless of compile-time switches.
 class Platform {
-public:
+ public:
 
   /// Return the default host device
   static std::shared_ptr<Device> GetDefaultDevice() {
@@ -290,22 +290,6 @@ class Platform {
   /// Create a set of CudaGPU Device using given GPU IDs.
   static const std::vector<std::shared_ptr<Device>>
   CreateCudaGPUsOn(const std::vector<int> &devices, size_t init_size = 0);
-#endif  // USE_CUDA
-
-  /// Create a \p num_devices set of valid OpenCL devices, regardless of
-  /// platforms. If there are fewer valid devices than requested, then this
-  /// method will return as many as possible.If OpenCL is not in use, this
-  /// method will return an empty array.
-  const std::vector<std::shared_ptr<Device> > CreateOpenclDevices(
-      const size_t num_devices);
-
-  /// Create a set of valid OpenCL devices, regardless of platforms, assigning
-  /// \p id to each device in sequence.
-  /// If there are fewer valid devices than requested, then this method will
-  /// return as many as possible.
-  /// If OpenCL is not in use, this method will return an empty array.
-  const std::vector<std::shared_ptr<Device> >
-  CreateOpenclDevices(const vector<int> &id);
 
   /// This function is implemented by Caffe (http://caffe.berkeleyvision.org/).
   /// This function checks the availability of GPU #device_id.
@@ -322,9 +306,34 @@
   /// the permission. cudaFree(0) is one of those with no side effect,
   /// except the context initialization.
   static bool CheckDevice(const int device_id);
+#endif  // USE_CUDA
 
-};
+#ifdef USE_OPENCL
+
+  static const int GetNumOpenclPlatforms();
+
+  static const int GetNumOpenclDevices();
+
+  static const std::shared_ptr<Device> GetDefaultOpenclDevice();
+
+  /// Create a \p num_devices set of valid OpenCL devices, regardless of
+  /// platforms. If there are fewer valid devices than requested, then this
+  /// method will return as many as possible.
+// static const std::vector<std::shared_ptr<Device>>
+// CreateOpenclDevices(const size_t num_devices);
+
+  /// Create a set of valid OpenCL devices, regardless of platforms, assigning
+  /// \p id to each device in sequence.
+  /// If there are fewer valid devices than requested, then this method will
+  /// return as many as possible.
+  /// If OpenCL is not in use, this method will return an empty array.
+// static const std::vector<std::shared_ptr<Device>>
+// CreateOpenclDevices(const std::vector<int> &id);
+#endif  // USE_OPENCL
+
+};
 
 }  // namespace singa

diff --git a/python/singa/device.py b/python/singa/device.py
index f250f9e6af..9e38b14a52 100644
--- a/python/singa/device.py
+++ b/python/singa/device.py
@@ -120,6 +120,23 @@ def create_cuda_gpu_on(device_id):
     return devices[0]
 
 
+def get_num_opencl_platforms():
+    return singa.Platform.GetNumOpenclPlatforms()
+
+
+def get_num_opencl_devices():
+    return singa.Platform.GetNumOpenclDevices()
+
+
+def create_opencl_device():
+    '''Create the default OpenCL device.
+
+    Returns:
+        a swig converted OpenCL device.
+    '''
+    return singa.Platform.GetDefaultOpenclDevice()
+
+
 default_device = singa.Platform.GetDefaultDevice()

diff --git a/python/singa/layer.py b/python/singa/layer.py
index 583126a287..8a75161835 100644
--- a/python/singa/layer.py
+++ b/python/singa/layer.py
@@ -350,7 +351,7 @@ def __init__(self, name, nb_kernels, kernel=3, stride=1, border_mode='same',
             self.conf.param.extend([bspecs])
             self.param_specs.append(bspecs)
 
-        _check_engine(engine, ['cudnn', 'singacpp'])
+        _check_engine(engine, ['cudnn', 'singacpp', 'singacl'])
         self.layer = _create_layer(engine, 'Convolution')
         if input_sample_shape is not None:
             self.setup(input_sample_shape)
@@ -407,7 +408,7 @@ def __init__(self, name, mode, kernel=3, stride=2, border_mode='same',
         conf = self.conf.pooling_conf
         conf = _set_kernel_stride_pad(conf, kernel, stride, border_mode, pad)
         conf.pool = mode
-        _check_engine(engine, ['cudnn', 'singacpp'])
+        _check_engine(engine, ['cudnn', 'singacpp', 'singacl'])
        self.layer = _create_layer(engine, 'Pooling')
         if input_sample_shape is not None:
             self.setup(input_sample_shape)

diff --git a/python/singa/net.py b/python/singa/net.py
index 36c70f8aea..ec479a98ef 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -19,7 +19,7 @@
 functions for net info, e.g., parameters.
 """
 
-
+from timeit import default_timer as timer
 from .proto.model_pb2 import kTrain, kEval
 import tensor
 import layer
@@ -28,6 +28,11 @@
 '''For display training information, e.g. L1 value of layer data'''
 verbose = False
 
+benchmark = False
+forward_time = {}  # forward time for each layer
+backward_time = {}  # backward time for each layer
+iter_time = [0, 0, 0]  # time for one iteration, forward, and backward
+
 
 class FeedForwardNet(object):
@@ -128,11 +133,24 @@ def train(self, x, y):
         Returns:
             gradients of parameters and the loss and metric values.
         '''
-        out = self.forward(kTrain, x)
-        l = self.loss.forward(kTrain, out, y)
-        if self.metric is not None:
-            m = self.metric.evaluate(out, y)
-        return self.backward(), (l.l1(), m)
+        if benchmark:
+            t1 = timer()
+            out = self.forward(kTrain, x)
+            l = self.loss.forward(kTrain, out, y)
+            t2 = timer()
+            ret = self.backward()
+            t3 = timer()
+            iter_time[0] += t3 - t1
+            iter_time[1] += t2 - t1
+            iter_time[2] += t3 - t2
+            return ret, (l.l1(), None)
+        else:
+            out = self.forward(kTrain, x)
+            l = self.loss.forward(kTrain, out, y)
+            if self.metric is not None:
+                m = self.metric.evaluate(out, y)
+            return self.backward(), (l.l1(), m)
 
     def evaluate(self, x, y):
         '''Evaluate the loss and metric of the given data.
@@ -249,7 +267,15 @@ def forward(self, flag, x, output=[]):
                     output_of_layer.pop(src.name)
             if len(inputs) == 1:
                 inputs = inputs[0]
-            out = cur.forward(flag, inputs)
+
+            if benchmark:
+                global forward_time
+                start_tick = timer()
+                out = cur.forward(flag, inputs)
+                forward_time[cur.name] += timer() - start_tick
+            else:
+                out = cur.forward(flag, inputs)
+
             if verbose:
                 disp_src = '+'.join([src.name for src in srcs])
                 disp_src += '-->' + cur.name
@@ -299,7 +325,13 @@
                 # del output_of_layer[dst.name]
             if len(grads) == 1:
                 grads = grads[0]
-            outs, _pgrads = cur.backward(kTrain, grads)
+            if benchmark:
+                global backward_time
+                start_tick = timer()
+                outs, _pgrads = cur.backward(kTrain, grads)
+                backward_time[cur.name] += timer() - start_tick
+            else:
+                outs, _pgrads = cur.backward(kTrain, grads)
             pgrads.append(_pgrads)
             if verbose:
                 disp_src = '+'.join(
@@ -321,6 +353,36 @@
                 ret.extend(pgrad)
         return ret
 
+    def start_benchmark(self):
+        '''Reset the internal timers to start the benchmark; must be called
+        before calling the train() function.
+        '''
+        global benchmark, iter_time
+        benchmark = True
+        iter_time = [0, 0, 0]
+        for ly in self.layers:
+            forward_time[ly.name] = 0
+            backward_time[ly.name] = 0
+        return iter_time, forward_time, backward_time
+
+    def stop_benchmark(self, num):
+        '''Stop the benchmark and return the time information.
+
+        Args:
+            num(int), number of total iterations
+
+        Returns:
+            time for the following procedures within one iteration
+            [forward-backward, forward, backward], [forward of each layer],
+            [backward of each layer]
+        '''
+        fp = []
+        bp = []
+        for lyr in self.ordered_layers:
+            fp.append((lyr.name, forward_time[lyr.name] / num))
+            bp.append((lyr.name, backward_time[lyr.name] / num))
+        return [t / num for t in iter_time], fp, bp
+
     def save(self, f, buffer_size=10, use_pickle=False):
         '''Save model parameters using io/snapshot.
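Taken together, the instrumentation above follows a simple protocol (a minimal sketch; `net`, `tx` and `ty` are assumed to be set up as in examples/benchmark/run.py):

    # Sketch of the benchmark protocol introduced in net.py above.
    num_iter = 10
    net.start_benchmark()                 # reset timers, enable timing
    for b in range(num_iter):
        net.train(tx, ty)                 # accumulates iteration/fwd/bwd times
    iter_time, fp, bp = net.stop_benchmark(num_iter)
    print "avg time per iteration = %.4f" % iter_time[0]
    for (name, t) in fp:                  # average forward time per layer
        print "%10s = %.4f" % (name, t)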
diff --git a/src/api/config.i.in b/src/api/config.i.in
index cea35171d1..05ddf6ed50 100644
--- a/src/api/config.i.in
+++ b/src/api/config.i.in
@@ -1,6 +1,7 @@
 // Pass in cmake configurations to swig
 #cmakedefine01 USE_CUDA
 #cmakedefine01 USE_CUDNN
+#cmakedefine01 USE_OPENCL
 #cmakedefine01 USE_PYTHON
 #cmakedefine01 USE_JAVA
 #cmakedefine CUDNN_VERSION ${CUDNN_VERSION}

diff --git a/src/api/core_device.i b/src/api/core_device.i
index a9bb840cb3..a5b7de6e90 100644
--- a/src/api/core_device.i
+++ b/src/api/core_device.i
@@ -44,7 +44,7 @@ namespace std{
 namespace singa{
 
 class Device {
- public:
+ public:
   virtual void SetRandSeed(unsigned seed) = 0;
   std::shared_ptr<Device> host();
   int id() const;
@@ -58,11 +58,24 @@ class Platform {
   static const std::pair<size_t, size_t> GetGPUMemSize(const int device);
   static const std::vector<std::pair<size_t, size_t>> GetGPUMemSize();
   static const std::string DeviceQuery(int id, bool verbose = false);
-  static const std::vector<std::shared_ptr<Device> >
+  static const std::vector<std::shared_ptr<Device>>
   CreateCudaGPUs(const size_t num_devices, size_t init_size = 0);
   static const std::vector<std::shared_ptr<Device>>
   CreateCudaGPUsOn(const std::vector<int> &devices, size_t init_size = 0);
 #endif  // USE_CUDA
+
+#if USE_OPENCL
+
+  static const int GetNumOpenclPlatforms();
+  static const int GetNumOpenclDevices();
+  static const std::shared_ptr<Device> GetDefaultOpenclDevice();
+// static const std::vector<std::shared_ptr<Device>>
+// CreateOpenclDevices(const size_t num_devices);
+// static const std::vector<std::shared_ptr<Device>>
+// CreateOpenclDevices();
+
+#endif  // USE_OPENCL
+
   static std::shared_ptr<Device> GetDefaultDevice();
 };

diff --git a/src/core/device/platform.cc b/src/core/device/platform.cc
index eb02c5bb1e..8ae15f8604 100644
--- a/src/core/device/platform.cc
+++ b/src/core/device/platform.cc
@@ -19,11 +19,12 @@
 
 #include "singa/core/device.h"
 #include "singa/singa_config.h"
-
-#ifdef USE_CUDA
+#include "singa/utils/opencl_utils.h"
 
 namespace singa {
 
+#ifdef USE_CUDA
+
 int Platform::GetNumGPUs() {
   int count;
   CUDA_CHECK(cudaGetDeviceCount(&count));
@@ -109,7 +110,7 @@ const string Platform::DeviceQuery(int device, bool verbose) {
   return out.str();
 }
 
-const vector<shared_ptr<Device> >
+const vector<shared_ptr<Device>>
 Platform::CreateCudaGPUs(const size_t num_devices, size_t init_size) {
   const vector<int> gpus = GetGPUIDs();
   CHECK_LE(num_devices, gpus.size());
@@ -117,7 +118,7 @@ Platform::CreateCudaGPUs(const size_t num_devices, size_t init_size) {
   return CreateCudaGPUsOn(use_gpus, init_size);
 }
 
-const vector<shared_ptr<Device> >
+const vector<shared_ptr<Device>>
 Platform::CreateCudaGPUsOn(const vector<int> &devices, size_t init_size) {
   MemPoolConf conf;
   if (init_size > 0)
@@ -137,8 +138,46 @@ Platform::CreateCudaGPUsOn(const vector<int> &devices, size_t init_size) {
   return ret;
 }
 
-}  // namespace singa
-
-#endif  // USE_CUDA
-#endif
\ No newline at end of file
+#endif  // USE_CUDA
+
+#ifdef USE_OPENCL
+
+const int Platform::GetNumOpenclPlatforms() {
+  auto all_platforms = viennacl::ocl::get_platforms();
+  return (int)all_platforms.size();
+}
+
+const int Platform::GetNumOpenclDevices() {
+  auto all_platforms = viennacl::ocl::get_platforms();
+  unsigned int total_num_devices = 0;
+  for (auto plat : all_platforms) {
+    auto all_devices = plat.devices(CL_DEVICE_TYPE_ALL);
+    total_num_devices += all_devices.size();
+  }
+  return (int)total_num_devices;
+}
+
+const std::shared_ptr<Device> Platform::GetDefaultOpenclDevice() {
+  return std::make_shared<OpenclDevice>();
+}
+/*
+static const std::vector<std::shared_ptr<Device>>
+Platform::CreateOpenclDevices(const size_t num_devices) {
+  auto all_platforms = viennacl::ocl::get_platforms();
+  for (auto plat : all_platforms) {
+    auto all_devices = plat.devices(CL_DEVICE_TYPE_ALL);
+    total_num_devices += all_devices.size();
+  }
+  return (int)total_num_devices;
+}
+
+static
+const std::vector<std::shared_ptr<Device>>
+Platform::CreateOpenclDevices(const std::vector<int> &id) {
+
+}
+*/
+#endif  // USE_OPENCL
+
+}  // namespace singa
+
+#endif

diff --git a/src/core/tensor/tensor_math_opencl.h b/src/core/tensor/tensor_math_opencl.h
index a209de4a57..c939dbbe07 100644
--- a/src/core/tensor/tensor_math_opencl.h
+++ b/src/core/tensor/tensor_math_opencl.h
@@ -440,36 +440,17 @@ void Amin<float, lang::Opencl>(const size_t num, const Block* in, size_t* out, C
   out[0] = temp[0];
   delete temp;
 }
-
+*/
 
 template<>
 void Asum<float, lang::Opencl>(const size_t num, const Block* in, float* out, Context* ctx) {
-  cl_int status = CL_SUCCESS;
-
-  std::string kname = "clkernel_asum";
-  auto kernel = ctx->kernels->at(kname);
-
-  cl::Buffer inbuf = *(static_cast<cl::Buffer*>(in->mutable_data()));
-
-  size_t size = sizeof(float) * num;
-  cl::Buffer outval(ctx->ocl_ctx, CL_MEM_WRITE_ONLY, size, nullptr, &status);
-  OCL_CHECK(status, "Failed to create buffer!");
-
-  kernel.setArg(0, (cl_int)num);
-  kernel.setArg(1, inbuf);
-  kernel.setArg(2, outval);
-  kernel.setArg(3, cl::Local(size));
+  viennacl::vector<float> v_in((const cl_mem)in->data(), num);
 
-  status = ctx->ocl_cmdq.enqueueNDRangeKernel(kernel, cl::NDRange(0), cl::NDRange(num));
-  OCL_CHECK(status, "Failed to enqueue kernel function!");
+  viennacl::vector<float> temp = viennacl::linalg::element_fabs(v_in);
 
-  float* temp = new float[num];
-  status = ctx->ocl_cmdq.enqueueReadBuffer(outval, CL_TRUE, 0, size, temp);
-  OCL_CHECK(status, "Failed to read from buffer!");
-  out[0] = temp[0];
-  delete temp;
+  out[0] = viennacl::linalg::sum(temp);
 }
-*/
+
 
 /// out = alpha * in + out
 template<>
 void Axpy<float, lang::Opencl>(const size_t num, const float alpha, const Block* in, Block* out, Context* ctx) {
@@ -528,7 +509,7 @@ void GEMV<float, lang::Opencl>(bool trans, const size_t m, const size_t n, const
 }
 
 /// multiply a matrix with a diagonal matrix constructed using values from 'v'.
-/// if matrix_lef_side is true, do M*v; else do v*M
+/// if matrix_left_side is true, do M*v; else do v*M
 template<>
 void DGMM<float, lang::Opencl>(bool side_right,
                                const size_t nrow, const size_t ncol,
@@ -541,9 +522,9 @@ void DGMM<float, lang::Opencl>(bool side_right,
   auto diag = viennacl::diag(v_buf);
 
   if (side_right) {
-    out_buf = viennacl::linalg::prod(diag, M_buf);
-  } else {
     out_buf = viennacl::linalg::prod(M_buf, diag);
+  } else {
+    out_buf = viennacl::linalg::prod(diag, M_buf);
   }
 }
@@ -577,9 +558,9 @@ void GEMM<float, lang::Opencl>(const bool transA, const bool transB,
 
 
 template <>
-void ComputeCrossEntropy<float, lang::Opencl>(const size_t batchsize, const size_t dim,
-                                              const Block *p, const Block *t, Block *loss,
-                                              Context *ctx) {
+void ComputeCrossEntropy<float, lang::Opencl>(bool int_target, const size_t batchsize,
+                                              const size_t dim, const Block *p, const Block *t,
+                                              Block *loss, Context *ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("tensor_math_opencl.cl", "clkernel_crossentropy");
@@ -592,7 +573,7 @@ void ComputeCrossEntropy<float, lang::Opencl>(const size_t batchsize, const size
 
 
 template <>
-void SoftmaxCrossEntropyBwd<float, lang::Opencl>(const size_t batchsize, const size_t dim,
+void SoftmaxCrossEntropyBwd<float, lang::Opencl>(bool int_target, const size_t batchsize, const size_t dim,
                                                  const Block *p, const Block *t, Block *grad,
                                                  Context *ctx) {
   auto ocl_ctx = get_context(ctx->vcl_ctx_id);

diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index 196d1375b1..8245bf01ac 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -173,10 +173,23 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
   Shape shape{batchsize, num_filters_, conv_height_, conv_width_};
   Tensor output(shape, dev, dtype);
 
+  LOG(ERROR) << "input: " << input.shape(0) << ", " << input.shape(1) << ", "
+             << input.shape(2) << ", " << input.shape(3);
+  LOG(ERROR) << "weight: " << weight_.shape(0) << ", " << weight_.shape(1);
+  LOG(ERROR) << "output: " << output.shape(0) << ", " << output.shape(1) << ", "
+             << output.shape(2) << ", " << output.shape(3);
+
   output.device()->Exec([input, output, this](Context *ctx) {
     Block *inblock = input.block(), *outblock = output.block(),
           *wblock = this->weight_.block();
     float alpha = 1.f, beta = 0.f;
+    /*
+    LOG(ERROR) << "before conv";
+    CHECK(inblock->data() != nullptr);
+    CHECK(wblock->data() != nullptr);
+    CHECK(outblock->mutable_data() != nullptr);
+    CHECK(workspace_.block()->mutable_data() != nullptr);
+    */
     cudnnConvolutionForward(ctx->cudnn_handle, &alpha, this->x_desc_,
                             inblock->data(), this->filter_desc_, wblock->data(),
                             this->conv_desc_, this->fp_alg_,
@@ -185,6 +198,7 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
                             this->y_desc_, outblock->mutable_data());
   }, {input.block(), weight_.block()}, {output.block()}, workspace_.block());
 
+  // LOG(ERROR) << "before bias";
   if (bias_term_) {
     output.device()->Exec([output, this](Context *ctx) {
       float beta = 1.f, alpha = 1.0f;

diff --git a/src/model/layer/opencl_convolution.cc b/src/model/layer/opencl_convolution.cc
index c43719ff7a..4b70a714a7 100644
--- a/src/model/layer/opencl_convolution.cc
+++ b/src/model/layer/opencl_convolution.cc
@@ -22,7 +22,7 @@
 
 namespace singa {
 
-RegisterLayerClass(opencl_convolution, OpenclConvolution);
+RegisterLayerClass(singacl_convolution, OpenclConvolution);
 
 /// \copydoc Layer::Forward(int flag, const Tensor&)
 const Tensor OpenclConvolution::Forward(int flag, const Tensor &input) {

diff --git a/src/model/layer/opencl_pooling.cc b/src/model/layer/opencl_pooling.cc
index 2e3533078f..f123270b9c 100644
--- a/src/model/layer/opencl_pooling.cc
+++ b/src/model/layer/opencl_pooling.cc
@@ -22,7 +22,7 @@
 
 namespace singa {
 
-RegisterLayerClass(opencl_pooling, OpenclPooling);
+RegisterLayerClass(singacl_pooling, OpenclPooling);
 
 const Tensor OpenclPooling::Forward(int flag, const Tensor &input) {
   CHECK(buf_.empty());
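End to end, the OpenCL path added by this patch would be exercised roughly as follows (a sketch only, assuming a build with USE_OPENCL enabled; `vgg` is the benchmark module above):

    # Sketch: query and use the new OpenCL bindings from Python.
    from singa import device
    import vgg

    print device.get_num_opencl_platforms(), device.get_num_opencl_devices()
    dev = device.create_opencl_device()                 # default OpenCL device
    net = vgg.create_net((3, 244, 244), use_ocl=True)   # selects engine 'singacl'
    net.to_device(dev)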