aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp')
-rw-r--r--src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp9
1 files changed, 5 insertions, 4 deletions
diff --git a/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp b/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp
index 472c75f222..9cd315efaa 100644
--- a/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp
@@ -8,6 +8,8 @@
#include <backends/aclCommon/ArmComputeUtils.hpp>
#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
+using namespace armnn::armcomputetensorutils;
+
namespace armnn
{
@@ -15,11 +17,10 @@ arm_compute::Status NeonNormalizationWorkloadValidate(const TensorInfo& input,
const TensorInfo& output,
const NormalizationDescriptor& descriptor)
{
- const arm_compute::TensorInfo aclInput = armcomputetensorutils::BuildArmComputeTensorInfo(input);
- const arm_compute::TensorInfo aclOutput = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+ const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
+ const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
- arm_compute::NormalizationLayerInfo normalizationInfo =
- armcomputetensorutils::BuildArmComputeNormalizationLayerInfo(descriptor);
+ arm_compute::NormalizationLayerInfo normalizationInfo = BuildArmComputeNormalizationLayerInfo(descriptor);
return arm_compute::NENormalizationLayer::validate(&aclInput, &aclOutput, normalizationInfo);
}