diff options
author | Idriss Chaouch <idriss.chaouch@arm.com> | 2023-08-28 14:28:31 +0100 |
---|---|---|
committer | Idriss Chaouch <idriss.chaouch@arm.com> | 2023-08-31 11:26:28 +0100 |
commit | 98e383eadf4e670d057ad725c7fe7924fea8e36b (patch) | |
tree | 35acac15aa69ab405887289cb9674d388f06f96b /src/backends/reference | |
parent | 2be039bce38a4fa436e8310dfe14ebfff20d57bd (diff) | |
download | armnn-98e383eadf4e670d057ad725c7fe7924fea8e36b.tar.gz |
IVGCVSW-7525 Add broadcast_to operator
Signed-off-by: Idriss Chaouch <idriss.chaouch@arm.com>
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I94ec5f9120b2d736fdf98d00ec5137a4efd739b8
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 53 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 5 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 5 | ||||
-rw-r--r-- | src/backends/reference/backend.mk | 1 | ||||
-rw-r--r-- | src/backends/reference/test/RefEndToEndTests.cpp | 12 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 56 | ||||
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefBroadcastToWorkload.cpp | 48 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefBroadcastToWorkload.hpp | 25 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefWorkloads.hpp | 1 |
10 files changed, 199 insertions, 9 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 0b1b9c7824..defdf0d807 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -100,6 +100,11 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type, infos[1], *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)), reasonIfUnsupported); + case LayerType::BroadcastTo: + return IsBroadcastToSupported(infos[0], + infos[1], + *(PolymorphicDowncast<const BroadcastToDescriptor*>(&descriptor)), + reasonIfUnsupported); case LayerType::Comparison: return IsComparisonSupported(infos[0], infos[1], @@ -807,20 +812,50 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input, return supported; } +bool RefLayerSupport::IsBroadcastToSupported(const TensorInfo& input, + const TensorInfo& output, + const BroadcastToDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported) const +{ + IgnoreUnused(descriptor); + + bool supported = true; + + std::array<DataType, 8> supportedTypes + { + DataType::Float32, + DataType::Float16, + DataType::QAsymmS8, + DataType::QAsymmU8, + DataType::QSymmS8, + DataType::QSymmS16, + DataType::Signed32, + DataType::Signed64 + }; + + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "BroadcastTo: input type not supported."); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "BroadcastTo: output type not supported"); + + return supported; +} + bool RefLayerSupport::IsCastSupported(const TensorInfo& input, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const { std::array<DataType, 9> supportedInputTypes = - { - DataType::Float32, - DataType::Float16, - DataType::QSymmS8, - DataType::QAsymmS8, - DataType::QAsymmU8, - DataType::QSymmS16, - DataType::Signed32 - }; + { + DataType::Float32, + DataType::Float16, + DataType::QSymmS8, + DataType::QAsymmS8, + DataType::QAsymmU8, + DataType::QSymmS16, + DataType::Signed32 + }; bool supported = true; supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported, diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 377afac62f..9e7175389a 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -54,6 +54,11 @@ public: const BatchToSpaceNdDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const; + bool IsBroadcastToSupported(const TensorInfo& input, + const TensorInfo& output, + const BroadcastToDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const; + bool IsCastSupported(const TensorInfo& input, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const; diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index fa2082d4f2..c4d9583a66 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -179,6 +179,11 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateWorkload(LayerType type, = PolymorphicDowncast<const BatchToSpaceNdQueueDescriptor*>(&descriptor); return std::make_unique<RefBatchToSpaceNdWorkload>(*batchToSpaceNdQueueDescriptor, info); } + case LayerType::BroadcastTo: + { + auto broadcastToQueueDescriptor = PolymorphicDowncast<const BroadcastToQueueDescriptor*>(&descriptor); + return std::make_unique<RefBroadcastToWorkload>(*broadcastToQueueDescriptor, info); + } case LayerType::Cast : { auto castQueueDescriptor = PolymorphicDowncast<const CastQueueDescriptor*>(&descriptor); diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 7f047af930..27ca8f607a 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -53,6 +53,7 @@ BACKEND_SOURCES := \ workloads/RefBatchMatMulWorkload.cpp \ workloads/RefBatchNormalizationWorkload.cpp \ workloads/RefBatchToSpaceNdWorkload.cpp \ + workloads/RefBroadcastToWorkload.cpp \ workloads/RefCastWorkload.cpp \ workloads/RefChannelShuffleWorkload.cpp \ workloads/RefComparisonWorkload.cpp \ diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp index 09d6ac5d20..e503d3fb7f 100644 --- a/src/backends/reference/test/RefEndToEndTests.cpp +++ b/src/backends/reference/test/RefEndToEndTests.cpp @@ -10,6 +10,7 @@ #include <backendsCommon/test/ArgMinMaxEndToEndTestImpl.hpp> #include <backendsCommon/test/BatchToSpaceNdEndToEndTestImpl.hpp> #include <backendsCommon/test/BatchMatMulEndToEndTestImpl.hpp> +#include <backendsCommon/test/BroadcastToEndToEndTestImpl.hpp> #include <backendsCommon/test/ChannelShuffleEndToEndTestImpl.hpp> #include <backendsCommon/test/ComparisonEndToEndTestImpl.hpp> #include <backendsCommon/test/ConcatEndToEndTestImpl.hpp> @@ -1728,4 +1729,15 @@ TEST_CASE("RefReshapeRemovalNCHWSecondEndToEnd") { ReshapeRemovalNCHWEndToEnd<armnn::DataType::Float32>(defaultBackends, true, false); } + +// BroadcastTo +TEST_CASE("RefBroadcastToEndToEndFloat32") +{ + BroadcastToEndToEnd<armnn::DataType::Float32>(defaultBackends); +} + +TEST_CASE("RefBroadcastToEndToEndWithElementWiseBinaryFloat32") +{ + BroadcastToEndToEndElementWiseBinary<armnn::DataType::Float32>(defaultBackends); +} } diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index a079bb712a..af4ed966b2 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -2823,4 +2823,60 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerInt8NoCifgWithPeeph ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmInt8WithCifgWithPeepholeNoProjection, UnidirectionalSequenceLstmInt8WithCifgWithPeepholeNoProjectionTest) +// Broadcast to +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestFloat32, BroadcastTo1dTest<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestFloat32, BroadcastTo2dAxis0Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestFloat32, BroadcastTo2dAxis1Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestFloat32, BroadcastTo3dAxis0Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestFloat32, BroadcastTo3dAxis1Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestFloat32, BroadcastTo3dAxis2Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestFloat32, BroadcastTo4dTest<DataType::Float32>) + +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestFloat16, BroadcastTo1dTest<DataType::Float16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestFloat16, BroadcastTo2dAxis0Test<DataType::Float16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestFloat16, BroadcastTo2dAxis1Test<DataType::Float16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestFloat16, BroadcastTo3dAxis0Test<DataType::Float16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestFloat16, BroadcastTo3dAxis1Test<DataType::Float16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestFloat16, BroadcastTo3dAxis2Test<DataType::Float16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestFloat16, BroadcastTo4dTest<DataType::Float16>) + +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestQAsymmS8, BroadcastTo1dTest<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestQAsymmS8, BroadcastTo2dAxis0Test<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestQAsymmS8, BroadcastTo2dAxis1Test<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestQAsymmS8, BroadcastTo3dAxis0Test<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestQAsymmS8, BroadcastTo3dAxis1Test<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestQAsymmS8, BroadcastTo3dAxis2Test<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestQAsymmS8, BroadcastTo4dTest<DataType::QAsymmS8>) + +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestQAsymmU8, BroadcastTo1dTest<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestQAsymmU8, BroadcastTo2dAxis0Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestQAsymmU8, BroadcastTo2dAxis1Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestQAsymmU8, BroadcastTo3dAxis0Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestQAsymmU8, BroadcastTo3dAxis1Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestQAsymmU8, BroadcastTo3dAxis2Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestQAsymmU8, BroadcastTo4dTest<DataType::QAsymmU8>) + +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestQSymmS8, BroadcastTo1dTest<DataType::QSymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestQSymmS8, BroadcastTo2dAxis0Test<DataType::QSymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestQSymmS8, BroadcastTo2dAxis1Test<DataType::QSymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestQSymmS8, BroadcastTo3dAxis0Test<DataType::QSymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestQSymmS8, BroadcastTo3dAxis1Test<DataType::QSymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestQSymmS8, BroadcastTo3dAxis2Test<DataType::QSymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestQSymmS8, BroadcastTo4dTest<DataType::QSymmS8>) + +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestQSymmS16, BroadcastTo1dTest<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestQSymmS16, BroadcastTo2dAxis0Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestQSymmS16, BroadcastTo2dAxis1Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestQSymmS16, BroadcastTo3dAxis0Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestQSymmS16, BroadcastTo3dAxis1Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestQSymmS16, BroadcastTo3dAxis2Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestQSymmS16, BroadcastTo4dTest<DataType::QSymmS16>) + +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestSigned32, BroadcastTo1dTest<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestSigned32, BroadcastTo2dAxis0Test<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestSigned32, BroadcastTo2dAxis1Test<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestSigned32, BroadcastTo3dAxis0Test<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestSigned32, BroadcastTo3dAxis1Test<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestSigned32, BroadcastTo3dAxis2Test<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestSigned32, BroadcastTo4dTest<DataType::Signed32>) }
\ No newline at end of file diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 9372568133..42f92aec1d 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -79,6 +79,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefBatchNormalizationWorkload.hpp RefBatchToSpaceNdWorkload.cpp RefBatchToSpaceNdWorkload.hpp + RefBroadcastToWorkload.cpp + RefBroadcastToWorkload.hpp RefCastWorkload.cpp RefCastWorkload.hpp RefChannelShuffleWorkload.cpp diff --git a/src/backends/reference/workloads/RefBroadcastToWorkload.cpp b/src/backends/reference/workloads/RefBroadcastToWorkload.cpp new file mode 100644 index 0000000000..3a6184d22e --- /dev/null +++ b/src/backends/reference/workloads/RefBroadcastToWorkload.cpp @@ -0,0 +1,48 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefBroadcastToWorkload.hpp" +#include "RefWorkloadUtils.hpp" +#include "Profiling.hpp" +#include "Broadcast.hpp" + +#include "Decoders.hpp" +#include "Encoders.hpp" + +namespace armnn +{ + +RefBroadcastToWorkload::RefBroadcastToWorkload(const BroadcastToQueueDescriptor& descriptor, const WorkloadInfo& info) + : RefBaseWorkload(descriptor, info) +{} + +void RefBroadcastToWorkload::Execute() const +{ + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefBroadcastToWorkload::ExecuteAsync(ExecutionData& executionData) +{ + WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data); + Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs); +} + +void RefBroadcastToWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const +{ + ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefBroadcastToWorkload_Execute"); + const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); + const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); + + std::unique_ptr<Decoder<float>> input = MakeDecoder<float>(inputInfo, inputs[0]->Map()); + std::unique_ptr<Encoder<float>> output= MakeEncoder<float>(outputInfo, outputs[0]->Map()); + + auto broadcastTo = [](float x) + { + return x; + }; + BroadcastLoop(inputInfo.GetShape(), outputInfo.GetShape()).Unroll(broadcastTo, + 0, *input, *output); +} +} // namespace armnn diff --git a/src/backends/reference/workloads/RefBroadcastToWorkload.hpp b/src/backends/reference/workloads/RefBroadcastToWorkload.hpp new file mode 100644 index 0000000000..ac947ae787 --- /dev/null +++ b/src/backends/reference/workloads/RefBroadcastToWorkload.hpp @@ -0,0 +1,25 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "RefBaseWorkload.hpp" + +namespace armnn +{ +class RefBroadcastToWorkload : public RefBaseWorkload<BroadcastToQueueDescriptor> +{ + +public: + explicit RefBroadcastToWorkload(const BroadcastToQueueDescriptor& descriptor, + const WorkloadInfo& info); + + void Execute() const override; + void ExecuteAsync(ExecutionData& executionData) override; + +private: + void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const; +}; +} // namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index a36eae501c..98aa27b8a9 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -10,6 +10,7 @@ #include "RefBatchMatMulWorkload.hpp" #include "RefBatchNormalizationWorkload.hpp" #include "RefBatchToSpaceNdWorkload.hpp" +#include "RefBroadcastToWorkload.hpp" #include "RefCastWorkload.hpp" #include "RefChannelShuffleWorkload.hpp" #include "RefComparisonWorkload.hpp" |