aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp22
1 files changed, 17 insertions, 5 deletions
diff --git a/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp b/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp
index 400ae18807..18085edab5 100644
--- a/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp
+++ b/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp
@@ -113,11 +113,23 @@ NeonDepthwiseConvolutionWorkload::NeonDepthwiseConvolutionWorkload(
arm_compute::PadStrideInfo padStrideInfo = BuildArmComputePadStrideInfo(m_Data.m_Parameters);
- // Check for optimisation opportunities.
- const bool use3x3Optimisation = (weightInfo.GetShape()[2] == 3) && (weightInfo.GetShape()[3] == 3);
- const bool use5x5Optimisation = (weightInfo.GetShape()[2] == 5) && (weightInfo.GetShape()[3] == 5);
-
- if (use3x3Optimisation||use5x5Optimisation)
+ const arm_compute::ITensorInfo* inputInfo = input.info();
+ const arm_compute::ITensorInfo* kernelInfo = m_KernelTensor->info();
+ const arm_compute::ITensorInfo* biasInfo = m_BiasTensor ? m_BiasTensor->info() : nullptr;
+ const arm_compute::ITensorInfo* outputInfo = output.info();
+
+ // Check for optimisation opportunities
+ arm_compute::Status optimizationStatus =
+ arm_compute::NEDepthwiseConvolutionLayerOptimized::validate(inputInfo,
+ kernelInfo,
+ biasInfo,
+ outputInfo,
+ padStrideInfo,
+ depthMultiplier,
+ arm_compute::ActivationLayerInfo(),
+ aclDilationInfo);
+
+ if (optimizationStatus.error_code() == arm_compute::ErrorCode::OK)
{
m_pDepthwiseConvolutionLayer = std::make_unique<arm_compute::NEDepthwiseConvolutionLayerOptimized>();
static_cast<arm_compute::NEDepthwiseConvolutionLayerOptimized*>(