diff options
Diffstat (limited to 'src/backends/neon')
-rw-r--r-- | src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp | 22 |
1 files changed, 17 insertions, 5 deletions
diff --git a/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp b/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp index 400ae18807..18085edab5 100644 --- a/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp +++ b/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp @@ -113,11 +113,23 @@ NeonDepthwiseConvolutionWorkload::NeonDepthwiseConvolutionWorkload( arm_compute::PadStrideInfo padStrideInfo = BuildArmComputePadStrideInfo(m_Data.m_Parameters); - // Check for optimisation opportunities. - const bool use3x3Optimisation = (weightInfo.GetShape()[2] == 3) && (weightInfo.GetShape()[3] == 3); - const bool use5x5Optimisation = (weightInfo.GetShape()[2] == 5) && (weightInfo.GetShape()[3] == 5); - - if (use3x3Optimisation||use5x5Optimisation) + const arm_compute::ITensorInfo* inputInfo = input.info(); + const arm_compute::ITensorInfo* kernelInfo = m_KernelTensor->info(); + const arm_compute::ITensorInfo* biasInfo = m_BiasTensor ? m_BiasTensor->info() : nullptr; + const arm_compute::ITensorInfo* outputInfo = output.info(); + + // Check for optimisation opportunities + arm_compute::Status optimizationStatus = + arm_compute::NEDepthwiseConvolutionLayerOptimized::validate(inputInfo, + kernelInfo, + biasInfo, + outputInfo, + padStrideInfo, + depthMultiplier, + arm_compute::ActivationLayerInfo(), + aclDilationInfo); + + if (optimizationStatus.error_code() == arm_compute::ErrorCode::OK) { m_pDepthwiseConvolutionLayer = std::make_unique<arm_compute::NEDepthwiseConvolutionLayerOptimized>(); static_cast<arm_compute::NEDepthwiseConvolutionLayerOptimized*>( |