From 5488cfaaa1a411cb5a18c81a98b90c6e3011abdc Mon Sep 17 00:00:00 2001 From: Matthew Jackson Date: Fri, 2 Aug 2019 14:53:10 +0100 Subject: IVGCVSW-3608 Fix Neon depthwise convolution 5x5 failure * Fix issue caused by layers with 5x5 filters and depth multipliers > 1 Signed-off-by: Matthew Jackson Signed-off-by: Aron Virginas-Tar Change-Id: I58435a1f0e3c7e69861dc130fad525a01e2a849d --- .../workloads/NeonDepthwiseConvolutionWorkload.cpp | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp b/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp index 400ae18807..18085edab5 100644 --- a/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp +++ b/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp @@ -113,11 +113,23 @@ NeonDepthwiseConvolutionWorkload::NeonDepthwiseConvolutionWorkload( arm_compute::PadStrideInfo padStrideInfo = BuildArmComputePadStrideInfo(m_Data.m_Parameters); - // Check for optimisation opportunities. - const bool use3x3Optimisation = (weightInfo.GetShape()[2] == 3) && (weightInfo.GetShape()[3] == 3); - const bool use5x5Optimisation = (weightInfo.GetShape()[2] == 5) && (weightInfo.GetShape()[3] == 5); - - if (use3x3Optimisation||use5x5Optimisation) + const arm_compute::ITensorInfo* inputInfo = input.info(); + const arm_compute::ITensorInfo* kernelInfo = m_KernelTensor->info(); + const arm_compute::ITensorInfo* biasInfo = m_BiasTensor ? 
m_BiasTensor->info() : nullptr; + const arm_compute::ITensorInfo* outputInfo = output.info(); + + // Check for optimisation opportunities + arm_compute::Status optimizationStatus = + arm_compute::NEDepthwiseConvolutionLayerOptimized::validate(inputInfo, + kernelInfo, + biasInfo, + outputInfo, + padStrideInfo, + depthMultiplier, + arm_compute::ActivationLayerInfo(), + aclDilationInfo); + + if (optimizationStatus.error_code() == arm_compute::ErrorCode::OK) { m_pDepthwiseConvolutionLayer = std::make_unique(); static_cast( -- cgit v1.2.1