From 81a26ad6b626ce2da83659d7c6c17b6104d1f203 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 23 Oct 2017 20:29:30 +0100 Subject: COMPMID-643: Add bias to CLDepthwiseConvolution. Change-Id: Ibfe7b8c1172d10cbcae7971fe86b82090519d31d Reviewed-on: http://mpd-gerrit.cambridge.arm.com/92798 Tested-by: Kaizen Reviewed-by: Jaroslaw Rzepecki Reviewed-by: Anthony Barbier --- src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp') diff --git a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp index 5c7fe7e0b4..743cd4a38f 100644 --- a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp +++ b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp @@ -41,13 +41,13 @@ CLDepthwiseIm2ColKernel::CLDepthwiseIm2ColKernel() { } -void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info) +void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != output->info()->dimension(2)); - ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) != (kernel_dims.width * kernel_dims.height)); + ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) != (kernel_dims.width * kernel_dims.height + ((has_bias) ? 1 : 0))); _input = input; _output = output; @@ -66,7 +66,10 @@ void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *outpu build_opts.emplace("-DSRC_HEIGHT=" + support::cpp11::to_string(input->info()->dimension(1))); build_opts.emplace("-DKERNEL_WIDTH=" + support::cpp11::to_string(kernel_dims.width)); build_opts.emplace("-DKERNEL_HEIGHT=" + support::cpp11::to_string(kernel_dims.height)); - + if(has_bias) + { + build_opts.emplace("-DHAS_BIAS"); + } _kernel = static_cast(CLKernelLibrary::get().create_kernel("depthwise_im2col", build_opts)); // Configure kernel window -- cgit v1.2.1