aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp16
1 files changed, 11 insertions, 5 deletions
diff --git a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp
index beff7ae8c4..28d4ff2759 100644
--- a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp
@@ -44,7 +44,8 @@ CLDepthwiseIm2ColKernel::CLDepthwiseIm2ColKernel()
namespace
{
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier,
+ const Size2D &dilation)
{
const size_t idx_c = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
@@ -55,16 +56,18 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input->data_type()) && has_bias);
ARM_COMPUTE_RETURN_ERROR_ON((input->dimension(idx_c) * depth_multiplier) != output->dimension(2));
ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != (kernel_dims.width * kernel_dims.height + ((has_bias) ? 1 : 0)));
+ ARM_COMPUTE_RETURN_ERROR_ON((dilation.x() < 1) || dilation.y() < 1);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
return Status{};
}
} // namespace
-void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier)
+void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier,
+ const Size2D &dilation)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, depth_multiplier));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, depth_multiplier, dilation));
_input = input;
_output = output;
@@ -88,6 +91,8 @@ void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *outpu
build_opts.add_option("-DKERNEL_WIDTH=" + support::cpp11::to_string(kernel_dims.width));
build_opts.add_option("-DKERNEL_HEIGHT=" + support::cpp11::to_string(kernel_dims.height));
build_opts.add_option("-DDEPTH_MULTIPLIER=" + support::cpp11::to_string(depth_multiplier));
+ build_opts.add_option("-DDILATION_X=" + support::cpp11::to_string(dilation.x()));
+ build_opts.add_option("-DDILATION_Y=" + support::cpp11::to_string(dilation.y()));
build_opts.add_option("-D" + string_from_data_layout(input->info()->data_layout()));
build_opts.add_option_if(has_bias, "-DHAS_BIAS");
build_opts.add_option_if_else(is_data_type_quantized_asymmetric(input->info()->data_type()),
@@ -104,9 +109,10 @@ void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *outpu
ICLKernel::configure_internal(win);
}
-Status CLDepthwiseIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier)
+Status CLDepthwiseIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier,
+ const Size2D &dilation)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias, depth_multiplier));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias, depth_multiplier, dilation));
return Status{};
}