diff options
Diffstat (limited to 'src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp')
-rw-r--r-- | src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp b/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp index 854ecd3c59..92c0396d86 100644 --- a/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp +++ b/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp @@ -4,10 +4,13 @@ // #include "NeonNormalizationFloatWorkload.hpp" -#include <neon/NeonLayerSupport.hpp> + +#include "NeonWorkloadUtils.hpp" #include <aclCommon/ArmComputeUtils.hpp> #include <aclCommon/ArmComputeTensorUtils.hpp> +#include <arm_compute/runtime/NEON/functions/NENormalizationLayer.h> + using namespace armnn::armcomputetensorutils; namespace armnn @@ -57,7 +60,6 @@ NeonNormalizationFloatWorkload::NeonNormalizationFloatWorkload(const Normalizati const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager) : FloatWorkload<NormalizationQueueDescriptor>(descriptor, info) - , m_NormalizationLayer(memoryManager) { m_Data.ValidateInputsOutputs("NeonNormalizationFloatWorkload", 1, 1); std::string reasonIfUnsupported; @@ -89,14 +91,15 @@ NeonNormalizationFloatWorkload::NeonNormalizationFloatWorkload(const Normalizati m_Data.m_Parameters.m_Beta, m_Data.m_Parameters.m_K, false); - - m_NormalizationLayer.configure(&input, &output, normalizationInfo); + auto layer = std::make_unique<arm_compute::NENormalizationLayer>(memoryManager); + layer->configure(&input, &output, normalizationInfo); + m_NormalizationLayer.reset(layer.release()); } void NeonNormalizationFloatWorkload::Execute() const { ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonNormalizationFloatWorkload_Execute"); - m_NormalizationLayer.run(); + m_NormalizationLayer->run(); } } //namespace armnn |