From d73d14fd77fe1405a33b3ecf3c56e1ac65647ff7 Mon Sep 17 00:00:00 2001 From: Ferran Balaguer Date: Mon, 10 Jun 2019 10:29:54 +0100 Subject: IVGCVSW-3229 Refactor L2Normalization workload to support multiple data types Signed-off-by: Ferran Balaguer Change-Id: I848056aad4b172d432664633eea000843d85a85d --- src/backends/reference/RefLayerSupport.cpp | 28 ++++++-- src/backends/reference/RefWorkloadFactory.cpp | 33 +++++----- src/backends/reference/backend.mk | 2 +- .../reference/test/RefCreateWorkloadTests.cpp | 14 +++- src/backends/reference/test/RefLayerTests.cpp | 8 +++ src/backends/reference/workloads/CMakeLists.txt | 4 +- .../RefL2NormalizationFloat32Workload.cpp | 69 -------------------- .../RefL2NormalizationFloat32Workload.hpp | 22 ------- .../workloads/RefL2NormalizationWorkload.cpp | 75 ++++++++++++++++++++++ .../workloads/RefL2NormalizationWorkload.hpp | 23 +++++++ src/backends/reference/workloads/RefWorkloads.hpp | 2 +- 11 files changed, 161 insertions(+), 119 deletions(-) delete mode 100644 src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp delete mode 100644 src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp create mode 100644 src/backends/reference/workloads/RefL2NormalizationWorkload.cpp create mode 100644 src/backends/reference/workloads/RefL2NormalizationWorkload.hpp (limited to 'src/backends/reference') diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index b508dfd29d..e42e4242e0 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -743,12 +743,30 @@ bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input, const L2NormalizationDescriptor& descriptor, Optional reasonIfUnsupported) const { - ignore_unused(output); ignore_unused(descriptor); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &FalseFuncU8<>); + // Define supported types + std::array supportedTypes = + { + DataType::Float32, + DataType::QuantisedSymm16 + }; + + bool supported = true; + + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference L2normalization: input type not supported."); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference L2normalization: output type not supported."); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference L2normalization: input and output types mismatched."); + + supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported, + "Reference L2normalization: input and output shapes have different " + "num total elements."); + + return supported; } bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index cb26f2642b..72762a48e6 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -27,15 +27,16 @@ std::unique_ptr RefWorkloadFactory::MakeWorkload(const QueueDescripto info); } -bool IsFloat16(const WorkloadInfo& info) +template +bool IsDataType(const WorkloadInfo& info) { - auto checkFloat16 = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == DataType::Float16;}; - auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkFloat16); + auto checkType = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == ArmnnType;}; + auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType); if (it != std::end(info.m_InputTensorInfos)) { return true; } - it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkFloat16); + it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType); if (it != std::end(info.m_OutputTensorInfos)) { return true; @@ -43,20 +44,14 @@ bool IsFloat16(const WorkloadInfo& info) return false; } +bool IsFloat16(const WorkloadInfo& info) +{ + return IsDataType(info); +} + bool IsUint8(const WorkloadInfo& info) { - auto checkUint8 = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == DataType::QuantisedAsymm8;}; - auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkUint8); - if (it != std::end(info.m_InputTensorInfos)) - { - return true; - } - it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkUint8); - if (it != std::end(info.m_OutputTensorInfos)) - { - return true; - } - return false; + return IsDataType(info); } RefWorkloadFactory::RefWorkloadFactory() @@ -260,7 +255,11 @@ std::unique_ptr RefWorkloadFactory::CreateFakeQuantization( std::unique_ptr RefWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + if (IsFloat16(info) || IsUint8(info)) + { + return MakeWorkload(descriptor, info); + } + return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor, diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 0d2b65d433..189f692033 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -43,7 +43,7 @@ BACKEND_SOURCES := \ workloads/RefFloorWorkload.cpp \ workloads/RefFullyConnectedWorkload.cpp \ workloads/RefGatherWorkload.cpp \ - workloads/RefL2NormalizationFloat32Workload.cpp \ + workloads/RefL2NormalizationWorkload.cpp \ workloads/RefLstmWorkload.cpp \ workloads/RefMeanFloat32Workload.cpp \ workloads/RefMeanUint8Workload.cpp \ diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index dbcf20169c..3de47d20e4 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -712,12 +712,22 @@ static void RefCreateL2NormalizationTest(DataLayout dataLayout) BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32) { - RefCreateL2NormalizationTest(DataLayout::NCHW); + RefCreateL2NormalizationTest(DataLayout::NCHW); } BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32Nhwc) { - RefCreateL2NormalizationTest(DataLayout::NHWC); + RefCreateL2NormalizationTest(DataLayout::NHWC); +} + +BOOST_AUTO_TEST_CASE(CreateL2NormalizationInt16) +{ + RefCreateL2NormalizationTest(DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(CreateL2NormalizationInt16Nhwc) +{ + RefCreateL2NormalizationTest(DataLayout::NHWC); } template diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 8ebb725a6f..30520cbf2e 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -472,11 +472,19 @@ ARMNN_AUTO_TEST_CASE(L2Normalization1d, L2Normalization1dTest, armnn::DataLayout ARMNN_AUTO_TEST_CASE(L2Normalization2d, L2Normalization2dTest, armnn::DataLayout::NCHW) ARMNN_AUTO_TEST_CASE(L2Normalization3d, L2Normalization3dTest, armnn::DataLayout::NCHW) ARMNN_AUTO_TEST_CASE(L2Normalization4d, L2Normalization4dTest, armnn::DataLayout::NCHW) +ARMNN_AUTO_TEST_CASE(L2Normalization1dInt16, L2Normalization1dInt16Test, armnn::DataLayout::NCHW) +ARMNN_AUTO_TEST_CASE(L2Normalization2dInt16, L2Normalization2dInt16Test, armnn::DataLayout::NCHW) +ARMNN_AUTO_TEST_CASE(L2Normalization3dInt16, L2Normalization3dInt16Test, armnn::DataLayout::NCHW) +ARMNN_AUTO_TEST_CASE(L2Normalization4dInt16, L2Normalization4dInt16Test, armnn::DataLayout::NCHW) ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dTest, armnn::DataLayout::NHWC) ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dTest, armnn::DataLayout::NHWC) ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dTest, armnn::DataLayout::NHWC) ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dTest, armnn::DataLayout::NHWC) +ARMNN_AUTO_TEST_CASE(L2Normalization1dInt16Nhwc, L2Normalization1dInt16Test, armnn::DataLayout::NHWC) +ARMNN_AUTO_TEST_CASE(L2Normalization2dInt16Nhwc, L2Normalization2dInt16Test, armnn::DataLayout::NHWC) +ARMNN_AUTO_TEST_CASE(L2Normalization3dInt16Nhwc, L2Normalization3dInt16Test, armnn::DataLayout::NHWC) +ARMNN_AUTO_TEST_CASE(L2Normalization4dInt16Nhwc, L2Normalization4dInt16Test, armnn::DataLayout::NHWC) // Pad ARMNN_AUTO_TEST_CASE(PadFloat322d, PadFloat322dTest) diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 4d11447280..41a553482d 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -70,8 +70,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefFullyConnectedWorkload.hpp RefGatherWorkload.cpp RefGatherWorkload.hpp - RefL2NormalizationFloat32Workload.cpp - RefL2NormalizationFloat32Workload.hpp + RefL2NormalizationWorkload.cpp + RefL2NormalizationWorkload.hpp RefLstmWorkload.cpp RefLstmWorkload.hpp RefConcatWorkload.cpp diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp deleted file mode 100644 index bc82739f6e..0000000000 --- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp +++ /dev/null @@ -1,69 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefL2NormalizationFloat32Workload.hpp" - -#include "RefWorkloadUtils.hpp" -#include "TensorBufferArrayView.hpp" - -#include "Profiling.hpp" - -#include - -using namespace armnnUtils; - -namespace armnn -{ - -void RefL2NormalizationFloat32Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefL2NormalizationFloat32Workload_Execute"); - - const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - - TensorBufferArrayView input(inputInfo.GetShape(), - GetInputTensorDataFloat(0, m_Data), - m_Data.m_Parameters.m_DataLayout); - TensorBufferArrayView output(outputInfo.GetShape(), - GetOutputTensorDataFloat(0, m_Data), - m_Data.m_Parameters.m_DataLayout); - - DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout); - - const unsigned int batches = inputInfo.GetShape()[0]; - const unsigned int channels = inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; - const unsigned int height = inputInfo.GetShape()[dataLayout.GetHeightIndex()]; - const unsigned int width = inputInfo.GetShape()[dataLayout.GetWidthIndex()]; - - for (unsigned int n = 0; n < batches; ++n) - { - for (unsigned int c = 0; c < channels; ++c) - { - for (unsigned int h = 0; h < height; ++h) - { - for (unsigned int w = 0; w < width; ++w) - { - float reduction = 0.0; - for (unsigned int d = 0; d < channels; ++d) - { - const float value = input.Get(n, d, h, w); - reduction += value * value; - } - - // Using std::max(reduction, epsilon) below would prevent against division by 0. - // However, at the time of writing: - // - This is not supported by the ACL functions used to implement L2Normalization in the CL - // backend. - // - The reference semantics for this operator do not include this parameter. - const float scale = 1.0f / sqrtf(reduction); - output.Get(n, c, h, w) = input.Get(n, c, h, w) * scale; - } - } - } - } -} - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp deleted file mode 100644 index 50ece0e905..0000000000 --- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include -#include - -namespace armnn -{ - -class RefL2NormalizationFloat32Workload : public Float32Workload -{ -public: - using Float32Workload::Float32Workload; - - void Execute() const override; -}; - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp b/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp new file mode 100644 index 0000000000..ce5699ef0b --- /dev/null +++ b/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp @@ -0,0 +1,75 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefL2NormalizationWorkload.hpp" + +#include "RefWorkloadUtils.hpp" +#include "Decoders.hpp" +#include "Encoders.hpp" +#include "DataLayoutIndexed.hpp" + + +#include "Profiling.hpp" + +#include + +using namespace armnnUtils; + +namespace armnn +{ +RefL2NormalizationWorkload::RefL2NormalizationWorkload( + const L2NormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload(descriptor, info) {} + + void RefL2NormalizationWorkload::Execute() const + { + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefL2NormalizationWorkload_Execute"); + + const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + + auto inputDecoder = MakeDecoder(inputInfo, m_Data.m_Inputs[0]->Map()); + auto outputEncoder = MakeEncoder(outputInfo, m_Data.m_Outputs[0]->Map()); + + DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout); + + const unsigned int batches = inputInfo.GetShape()[0]; + const unsigned int channels = inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; + const unsigned int height = inputInfo.GetShape()[dataLayout.GetHeightIndex()]; + const unsigned int width = inputInfo.GetShape()[dataLayout.GetWidthIndex()]; + + for (unsigned int n = 0; n < batches; ++n) + { + for (unsigned int c = 0; c < channels; ++c) + { + for (unsigned int h = 0; h < height; ++h) + { + for (unsigned int w = 0; w < width; ++w) + { + float reduction = 0.0; + for (unsigned int d = 0; d < channels; ++d) + { + unsigned int inputIndex = dataLayout.GetIndex(inputInfo.GetShape(), n, d, h, w); + + (*inputDecoder)[inputIndex]; + const float value = inputDecoder->Get(); + reduction += value * value; + } + + unsigned int index = dataLayout.GetIndex(inputInfo.GetShape(), n, c, h, w); + + const float scale = 1.0f / sqrtf(reduction); + + (*inputDecoder)[index]; + (*outputEncoder)[index]; + outputEncoder->Set(inputDecoder->Get() * scale); + } + } + } + } + } + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp b/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp new file mode 100644 index 0000000000..4beedc9992 --- /dev/null +++ b/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp @@ -0,0 +1,23 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include + +namespace armnn +{ + +class RefL2NormalizationWorkload : public BaseWorkload +{ +public: + explicit RefL2NormalizationWorkload(const L2NormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info); + + void Execute() const override; +}; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 53f7aa2efb..1a2dec402e 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -12,7 +12,7 @@ #include "RefConvolution2dWorkload.hpp" #include "RefSplitterWorkload.hpp" #include "RefResizeBilinearUint8Workload.hpp" -#include "RefL2NormalizationFloat32Workload.hpp" +#include "RefL2NormalizationWorkload.hpp" #include "RefActivationWorkload.hpp" #include "RefPooling2dWorkload.hpp" #include "RefWorkloadUtils.hpp" -- cgit v1.2.1