diff options
author | Ferran Balaguer <ferran.balaguer@arm.com> | 2019-06-10 10:29:54 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-06-14 13:03:28 +0000 |
commit | d73d14fd77fe1405a33b3ecf3c56e1ac65647ff7 (patch) | |
tree | a8f51e7d7c652653dc13c8c978aca347463c03a0 /src/backends/reference/RefWorkloadFactory.cpp | |
parent | 0421e7f22d9ccd5d810b345731b766a96c841492 (diff) | |
download | armnn-d73d14fd77fe1405a33b3ecf3c56e1ac65647ff7.tar.gz |
IVGCVSW-3229 Refactor L2Normalization workload to support multiple data types
Signed-off-by: Ferran Balaguer <ferran.balaguer@arm.com>
Change-Id: I848056aad4b172d432664633eea000843d85a85d
Diffstat (limited to 'src/backends/reference/RefWorkloadFactory.cpp')
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 33 |
1 files changed, 16 insertions, 17 deletions
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index cb26f2642b..72762a48e6 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -27,15 +27,16 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescripto info); } -bool IsFloat16(const WorkloadInfo& info) +template <DataType ArmnnType> +bool IsDataType(const WorkloadInfo& info) { - auto checkFloat16 = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == DataType::Float16;}; - auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkFloat16); + auto checkType = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == ArmnnType;}; + auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType); if (it != std::end(info.m_InputTensorInfos)) { return true; } - it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkFloat16); + it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType); if (it != std::end(info.m_OutputTensorInfos)) { return true; @@ -43,20 +44,14 @@ bool IsFloat16(const WorkloadInfo& info) return false; } +bool IsFloat16(const WorkloadInfo& info) +{ + return IsDataType<DataType::Float16>(info); +} + bool IsUint8(const WorkloadInfo& info) { - auto checkUint8 = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == DataType::QuantisedAsymm8;}; - auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkUint8); - if (it != std::end(info.m_InputTensorInfos)) - { - return true; - } - it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkUint8); - if (it != std::end(info.m_OutputTensorInfos)) - { - return true; - } - return false; + return IsDataType<DataType::QuantisedAsymm8>(info); } RefWorkloadFactory::RefWorkloadFactory() @@ -260,7 +255,11 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFakeQuantization( std::unique_ptr<IWorkload> RefWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefL2NormalizationFloat32Workload, NullWorkload>(descriptor, info); + if (IsFloat16(info) || IsUint8(info)) + { + return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info); + } + return std::make_unique<RefL2NormalizationWorkload>(descriptor, info); } std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor, |