aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/workloads
diff options
context:
space:
mode:
authorFrancisMurtagh <francis.murtagh@arm.com>2018-11-29 17:13:36 +0000
committerFrancis Murtagh <francis.murtagh@arm.com>2018-12-03 09:45:51 +0000
commita1b463f343befac766b6dd886aa4624dc381677a (patch)
treebec896aaeb157b17881d5a68bd7daec69c93612f /src/backends/neon/workloads
parent6460c274bc81733e9f36d1719a7db98a74d8db1f (diff)
downloadarmnn-a1b463f343befac766b6dd886aa4624dc381677a.tar.gz
IVGCVSW-2118 L2Normalization ACL function used for Neon
* Changed NeonL2Normalisation to use NEL2NormalizeLayer to normalise along the channel axis in either NCHW or NHWC format Change-Id: Ibaf119b6a3de3c0f80f94b1c5fe9a356cf1fbd0e
Diffstat (limited to 'src/backends/neon/workloads')
-rw-r--r--src/backends/neon/workloads/NeonL2NormalizationFloatWorkload.cpp12
-rw-r--r--src/backends/neon/workloads/NeonL2NormalizationFloatWorkload.hpp3
2 files changed, 6 insertions, 9 deletions
diff --git a/src/backends/neon/workloads/NeonL2NormalizationFloatWorkload.cpp b/src/backends/neon/workloads/NeonL2NormalizationFloatWorkload.cpp
index df8caefbd2..afaa700624 100644
--- a/src/backends/neon/workloads/NeonL2NormalizationFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonL2NormalizationFloatWorkload.cpp
@@ -17,10 +17,9 @@ arm_compute::Status NeonL2NormalizationWorkloadValidate(const TensorInfo& input,
const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
- arm_compute::NormalizationLayerInfo normalizationInfo =
- CreateAclNormalizationLayerInfoForL2Normalization(input, descriptor.m_DataLayout);
+ unsigned int axis = (descriptor.m_DataLayout == DataLayout::NCHW) ? 2 : 0;
- return arm_compute::NENormalizationLayer::validate(&aclInput, &aclOutput, normalizationInfo);
+ return arm_compute::NEL2NormalizeLayer::validate(&aclInput, &aclOutput, axis);
}
NeonL2NormalizationFloatWorkload::NeonL2NormalizationFloatWorkload(const L2NormalizationQueueDescriptor& descriptor,
@@ -37,10 +36,9 @@ NeonL2NormalizationFloatWorkload::NeonL2NormalizationFloatWorkload(const L2Norma
input.info()->set_data_layout(aclDataLayout);
output.info()->set_data_layout(aclDataLayout);
- m_Layer.configure(&input,
- &output,
- CreateAclNormalizationLayerInfoForL2Normalization(
- info.m_InputTensorInfos[0], m_Data.m_Parameters.m_DataLayout));
+ unsigned int axis = (m_Data.m_Parameters.m_DataLayout == DataLayout::NCHW) ? 2 : 0;
+
+ m_Layer.configure(&input, &output, axis);
}
void NeonL2NormalizationFloatWorkload::Execute() const
diff --git a/src/backends/neon/workloads/NeonL2NormalizationFloatWorkload.hpp b/src/backends/neon/workloads/NeonL2NormalizationFloatWorkload.hpp
index 35d0282414..30058c571f 100644
--- a/src/backends/neon/workloads/NeonL2NormalizationFloatWorkload.hpp
+++ b/src/backends/neon/workloads/NeonL2NormalizationFloatWorkload.hpp
@@ -25,8 +25,7 @@ public:
virtual void Execute() const override;
private:
- // Purposely not a NEL2Normalize function. See constructor.
- mutable arm_compute::NENormalizationLayer m_Layer;
+ mutable arm_compute::NEL2NormalizeLayer m_Layer;
};
} //namespace armnn