aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornarpra01 <narumol.prangnawarat@arm.com>2018-09-27 16:46:14 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-10 16:16:57 +0100
commit33cea4db0b2729c5dbd50f9c0985578c60baffdd (patch)
treeecea0ce825b52a2225788362c76f40f6ec548cda
parent0dbe0ee25312b728d77383d11c465156e64ae757 (diff)
downloadarmnn-33cea4db0b2729c5dbd50f9c0985578c60baffdd.tar.gz
IVGCVSW-1919 - data layout parameter for Normalization
Change-Id: I33dce72bb0f1e25425dc058d6213a7cdf56eecd2
-rw-r--r--include/armnn/Descriptors.hpp2
-rw-r--r--src/backends/ClWorkloads/ClNormalizationFloatWorkload.cpp17
-rw-r--r--src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp9
3 files changed, 16 insertions, 12 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 893db69700..bc1b59bdf5 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -262,6 +262,7 @@ struct NormalizationDescriptor
, m_Alpha(0.f)
, m_Beta(0.f)
, m_K(0.f)
+ , m_DataLayout(DataLayout::NCHW)
{}
NormalizationAlgorithmChannel m_NormChannelType;
@@ -270,6 +271,7 @@ struct NormalizationDescriptor
float m_Alpha;
float m_Beta;
float m_K;
+ DataLayout m_DataLayout;
};
struct BatchNormalizationDescriptor
diff --git a/src/backends/ClWorkloads/ClNormalizationFloatWorkload.cpp b/src/backends/ClWorkloads/ClNormalizationFloatWorkload.cpp
index cf2cad16aa..d5863b444c 100644
--- a/src/backends/ClWorkloads/ClNormalizationFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClNormalizationFloatWorkload.cpp
@@ -11,17 +11,19 @@
#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
#include "ClWorkloadUtils.hpp"
+using namespace armnn::armcomputetensorutils;
+
namespace armnn
{
-arm_compute::Status ClNormalizationWorkloadValidate(const TensorInfo& input, const TensorInfo& output,
- const NormalizationDescriptor& descriptor)
+arm_compute::Status ClNormalizationWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output,
+ const NormalizationDescriptor& descriptor)
{
- const arm_compute::TensorInfo aclInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input);
- const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+ const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
+ const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
- arm_compute::NormalizationLayerInfo layerInfo =
- armcomputetensorutils::BuildArmComputeNormalizationLayerInfo(descriptor);
+ arm_compute::NormalizationLayerInfo layerInfo = BuildArmComputeNormalizationLayerInfo(descriptor);
return arm_compute::CLNormalizationLayer::validate(&aclInputInfo, &aclOutputInfo, layerInfo);
}
@@ -35,8 +37,7 @@ ClNormalizationFloatWorkload::ClNormalizationFloatWorkload(const NormalizationQu
arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
- arm_compute::NormalizationLayerInfo normalizationInfo =
- armcomputetensorutils::BuildArmComputeNormalizationLayerInfo(m_Data.m_Parameters);
+ arm_compute::NormalizationLayerInfo normalizationInfo = BuildArmComputeNormalizationLayerInfo(m_Data.m_Parameters);
m_NormalizationLayer.configure(&input, &output, normalizationInfo);
};
diff --git a/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp b/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp
index 472c75f222..9cd315efaa 100644
--- a/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp
@@ -8,6 +8,8 @@
#include <backends/aclCommon/ArmComputeUtils.hpp>
#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
+using namespace armnn::armcomputetensorutils;
+
namespace armnn
{
@@ -15,11 +17,10 @@ arm_compute::Status NeonNormalizationWorkloadValidate(const TensorInfo& input,
const TensorInfo& output,
const NormalizationDescriptor& descriptor)
{
- const arm_compute::TensorInfo aclInput = armcomputetensorutils::BuildArmComputeTensorInfo(input);
- const arm_compute::TensorInfo aclOutput = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+ const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
+ const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
- arm_compute::NormalizationLayerInfo normalizationInfo =
- armcomputetensorutils::BuildArmComputeNormalizationLayerInfo(descriptor);
+ arm_compute::NormalizationLayerInfo normalizationInfo = BuildArmComputeNormalizationLayerInfo(descriptor);
return arm_compute::NENormalizationLayer::validate(&aclInput, &aclOutput, normalizationInfo);
}