aboutsummaryrefslogtreecommitdiff
path: root/src/backends/ClWorkloads/ClNormalizationFloatWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/ClWorkloads/ClNormalizationFloatWorkload.cpp')
-rw-r--r--src/backends/ClWorkloads/ClNormalizationFloatWorkload.cpp17
1 files changed, 9 insertions, 8 deletions
diff --git a/src/backends/ClWorkloads/ClNormalizationFloatWorkload.cpp b/src/backends/ClWorkloads/ClNormalizationFloatWorkload.cpp
index cf2cad16aa..d5863b444c 100644
--- a/src/backends/ClWorkloads/ClNormalizationFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClNormalizationFloatWorkload.cpp
@@ -11,17 +11,19 @@
#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
#include "ClWorkloadUtils.hpp"
+using namespace armnn::armcomputetensorutils;
+
namespace armnn
{
-arm_compute::Status ClNormalizationWorkloadValidate(const TensorInfo& input, const TensorInfo& output,
- const NormalizationDescriptor& descriptor)
+arm_compute::Status ClNormalizationWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output,
+ const NormalizationDescriptor& descriptor)
{
- const arm_compute::TensorInfo aclInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input);
- const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+ const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
+ const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
- arm_compute::NormalizationLayerInfo layerInfo =
- armcomputetensorutils::BuildArmComputeNormalizationLayerInfo(descriptor);
+ arm_compute::NormalizationLayerInfo layerInfo = BuildArmComputeNormalizationLayerInfo(descriptor);
return arm_compute::CLNormalizationLayer::validate(&aclInputInfo, &aclOutputInfo, layerInfo);
}
@@ -35,8 +37,7 @@ ClNormalizationFloatWorkload::ClNormalizationFloatWorkload(const NormalizationQu
arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
- arm_compute::NormalizationLayerInfo normalizationInfo =
- armcomputetensorutils::BuildArmComputeNormalizationLayerInfo(m_Data.m_Parameters);
+ arm_compute::NormalizationLayerInfo normalizationInfo = BuildArmComputeNormalizationLayerInfo(m_Data.m_Parameters);
m_NormalizationLayer.configure(&input, &output, normalizationInfo);
};