aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
author: Viet-Hoa Do <viet-hoa.do@arm.com> 2023-08-16 14:14:39 +0100
committer: Viet-Hoa Do <viet-hoa.do@arm.com> 2023-08-17 13:58:29 +0000
commit: 580ecd750ed76c72d59a8b8d23566686e6aa9c7b (patch)
tree: 77c7094c82cd3dd1eb931b22619a9614decf9421 /src
parent: 607509748bf81d72a7bd0f54d5fe5f7504c3b6ff (diff)
download: ComputeLibrary-580ecd750ed76c72d59a8b8d23566686e6aa9c7b.tar.gz
Fix depthwise convolution not using assembly kernel
* Take dilation into account when checking padding.

Resolves: COMPMID-6348
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: I897a13ba7f37382733c35c1701d1ec310ed55331
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10147
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src')
-rw-r--r-- src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp | 8
1 files changed, 6 insertions, 2 deletions
diff --git a/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp b/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp
index e092c836af..b503a8b734 100644
--- a/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp
+++ b/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp
@@ -306,11 +306,15 @@ Status CpuDepthwiseConv2dAssemblyWrapperKernel::validate(const ITensorInfo *src,
// Assembly kernels cannot work with padding greater than the kernel.
const auto &padding = info.pad_stride_info;
+ const auto &dilation = info.dilation;
const auto &wei_shape = weights->tensor_shape();
+ const auto dilated_wei_w = wei_shape[1] + (wei_shape[1] - 1) * (dilation.x() - 1);
+ const auto dilated_wei_h = wei_shape[2] + (wei_shape[2] - 1) * (dilation.y() - 1);
+
ARM_COMPUTE_RETURN_ERROR_ON(
- padding.pad_top() >= wei_shape[2] || padding.pad_bottom() >= wei_shape[2] ||
- padding.pad_left() >= wei_shape[1] || padding.pad_right() >= wei_shape[1]
+ padding.pad_left() >= dilated_wei_w || padding.pad_right() >= dilated_wei_w ||
+ padding.pad_top() >= dilated_wei_h || padding.pad_bottom() >= dilated_wei_h
);
return Status{};