diff options
Diffstat (limited to 'src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp index 5c7fe7e0b4..743cd4a38f 100644 --- a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp +++ b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp @@ -41,13 +41,13 @@ CLDepthwiseIm2ColKernel::CLDepthwiseIm2ColKernel() { } -void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info) +void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != output->info()->dimension(2)); - ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) != (kernel_dims.width * kernel_dims.height)); + ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) != (kernel_dims.width * kernel_dims.height + ((has_bias) ? 1 : 0))); _input = input; _output = output; @@ -66,7 +66,10 @@ void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *outpu build_opts.emplace("-DSRC_HEIGHT=" + support::cpp11::to_string(input->info()->dimension(1))); build_opts.emplace("-DKERNEL_WIDTH=" + support::cpp11::to_string(kernel_dims.width)); build_opts.emplace("-DKERNEL_HEIGHT=" + support::cpp11::to_string(kernel_dims.height)); - + if(has_bias) + { + build_opts.emplace("-DHAS_BIAS"); + } _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("depthwise_im2col", build_opts)); // Configure kernel window |