diff options
Diffstat (limited to 'src/backends/neon/workloads/NeonBatchNormalizationWorkload.cpp')
-rw-r--r-- | src/backends/neon/workloads/NeonBatchNormalizationWorkload.cpp | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/src/backends/neon/workloads/NeonBatchNormalizationWorkload.cpp b/src/backends/neon/workloads/NeonBatchNormalizationWorkload.cpp index 44d5035431..fc80f413e8 100644 --- a/src/backends/neon/workloads/NeonBatchNormalizationWorkload.cpp +++ b/src/backends/neon/workloads/NeonBatchNormalizationWorkload.cpp @@ -4,9 +4,13 @@ // #include "NeonBatchNormalizationWorkload.hpp" + +#include "NeonWorkloadUtils.hpp" + #include <backendsCommon/CpuTensorHandle.hpp> #include <aclCommon/ArmComputeTensorUtils.hpp> -#include <armnn/ArmNN.hpp> + +#include <arm_compute/runtime/NEON/functions/NEBatchNormalizationLayer.h> namespace armnn { @@ -68,13 +72,15 @@ NeonBatchNormalizationWorkload::NeonBatchNormalizationWorkload( m_Beta = std::make_unique<arm_compute::Tensor>(); BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo()); - m_Layer.configure(&input, - &output, - m_Mean.get(), - m_Variance.get(), - m_Beta.get(), - m_Gamma.get(), - m_Data.m_Parameters.m_Eps); + auto layer = std::make_unique<arm_compute::NEBatchNormalizationLayer>(); + layer->configure(&input, + &output, + m_Mean.get(), + m_Variance.get(), + m_Beta.get(), + m_Gamma.get(), + m_Data.m_Parameters.m_Eps); + m_Layer.reset(layer.release()); InitializeArmComputeTensorData(*m_Mean, m_Data.m_Mean); InitializeArmComputeTensorData(*m_Variance, m_Data.m_Variance); @@ -83,14 +89,14 @@ NeonBatchNormalizationWorkload::NeonBatchNormalizationWorkload( // Force Compute Library to perform the necessary copying and reshaping, after which // delete all the input tensors that will no longer be needed - m_Layer.prepare(); + m_Layer->prepare(); FreeUnusedTensors(); } void NeonBatchNormalizationWorkload::Execute() const { ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonBatchNormalizationWorkload_Execute"); - m_Layer.run(); + m_Layer->run(); } void NeonBatchNormalizationWorkload::FreeUnusedTensors() |