diff options
Diffstat (limited to 'src/backends/neon/NeonLayerSupport.cpp')
-rw-r--r-- | src/backends/neon/NeonLayerSupport.cpp | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp index 9dc8a01778..853a518b45 100644 --- a/src/backends/neon/NeonLayerSupport.cpp +++ b/src/backends/neon/NeonLayerSupport.cpp @@ -5,6 +5,7 @@ #include "NeonLayerSupport.hpp" #include "NeonBackendId.hpp" +#include "NeonBackendModelContext.hpp" #include <armnn/Descriptors.hpp> #include <armnn/Exceptions.hpp> @@ -15,6 +16,7 @@ #include <InternalTypes.hpp> #include <LayerSupportCommon.hpp> #include <armnn/utility/IgnoreUnused.hpp> +#include <armnn/utility/PolymorphicDowncast.hpp> #if defined(ARMCOMPUTENEON_ENABLED) #include <aclCommon/ArmComputeUtils.hpp> @@ -125,6 +127,16 @@ inline bool IsWorkloadSupported(FuncType& func, Optional<std::string&> reasonIfU #endif } // anonymous namespace +NeonLayerSupport::NeonLayerSupport(const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr) + : m_ModelContextPtr(modelContextPtr) +{ +} + +NeonLayerSupport::NeonLayerSupport() + : m_ModelContextPtr(nullptr) +{ +} + bool NeonLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const @@ -311,13 +323,29 @@ bool NeonLayerSupport::IsConvolution2dSupported(const TensorInfo& input, const Optional<TensorInfo>& biases, Optional<std::string&> reasonIfUnsupported) const { + bool isFastMathEnabled = false; +#if defined(ARMCOMPUTENEON_ENABLED) + if (m_ModelContextPtr) + { + if (m_ModelContextPtr.get() != nullptr) + { + auto modelOptions = armnn::PolymorphicDowncast<NeonBackendModelContext*>(m_ModelContextPtr.get()); + if (modelOptions) + { + isFastMathEnabled = modelOptions->IsFastMathEnabled(); + } + } + } +#endif + FORWARD_WORKLOAD_VALIDATE_FUNC(NeonConvolution2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor, weights, - biases); + biases, + isFastMathEnabled); } bool NeonLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input, |