about summary refs log tree commit diff
path: root/src/backends/reference
diff options
context:
space:
mode:
author Idriss Chaouch <idriss.chaouch@arm.com> 2023-08-28 14:28:31 +0100
committer Idriss Chaouch <idriss.chaouch@arm.com> 2023-08-31 11:26:28 +0100
commit 98e383eadf4e670d057ad725c7fe7924fea8e36b (patch)
tree 35acac15aa69ab405887289cb9674d388f06f96b /src/backends/reference
parent 2be039bce38a4fa436e8310dfe14ebfff20d57bd (diff)
download armnn-98e383eadf4e670d057ad725c7fe7924fea8e36b.tar.gz
IVGCVSW-7525 Add broadcast_to operator
Signed-off-by: Idriss Chaouch <idriss.chaouch@arm.com>
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I94ec5f9120b2d736fdf98d00ec5137a4efd739b8
Diffstat (limited to 'src/backends/reference')
-rw-r--r--  src/backends/reference/RefLayerSupport.cpp                     53
-rw-r--r--  src/backends/reference/RefLayerSupport.hpp                      5
-rw-r--r--  src/backends/reference/RefWorkloadFactory.cpp                   5
-rw-r--r--  src/backends/reference/backend.mk                               1
-rw-r--r--  src/backends/reference/test/RefEndToEndTests.cpp               12
-rw-r--r--  src/backends/reference/test/RefLayerTests.cpp                  56
-rw-r--r--  src/backends/reference/workloads/CMakeLists.txt                 2
-rw-r--r--  src/backends/reference/workloads/RefBroadcastToWorkload.cpp    48
-rw-r--r--  src/backends/reference/workloads/RefBroadcastToWorkload.hpp    25
-rw-r--r--  src/backends/reference/workloads/RefWorkloads.hpp               1
10 files changed, 199 insertions, 9 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 0b1b9c7824..defdf0d807 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -100,6 +100,11 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type,
infos[1],
*(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
reasonIfUnsupported);
+ case LayerType::BroadcastTo:
+ return IsBroadcastToSupported(infos[0],
+ infos[1],
+ *(PolymorphicDowncast<const BroadcastToDescriptor*>(&descriptor)),
+ reasonIfUnsupported);
case LayerType::Comparison:
return IsComparisonSupported(infos[0],
infos[1],
@@ -807,20 +812,50 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
return supported;
}
+bool RefLayerSupport::IsBroadcastToSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const BroadcastToDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ IgnoreUnused(descriptor);
+
+ bool supported = true;
+
+ std::array<DataType, 8> supportedTypes
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmS8,
+ DataType::QAsymmU8,
+ DataType::QSymmS8,
+ DataType::QSymmS16,
+ DataType::Signed32,
+ DataType::Signed64
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "BroadcastTo: input type not supported.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "BroadcastTo: output type not supported");
+
+ return supported;
+}
+
bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
std::array<DataType, 9> supportedInputTypes =
- {
- DataType::Float32,
- DataType::Float16,
- DataType::QSymmS8,
- DataType::QAsymmS8,
- DataType::QAsymmU8,
- DataType::QSymmS16,
- DataType::Signed32
- };
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QSymmS8,
+ DataType::QAsymmS8,
+ DataType::QAsymmU8,
+ DataType::QSymmS16,
+ DataType::Signed32
+ };
bool supported = true;
supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 377afac62f..9e7175389a 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -54,6 +54,11 @@ public:
const BatchToSpaceNdDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
+ bool IsBroadcastToSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const BroadcastToDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
+
bool IsCastSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index fa2082d4f2..c4d9583a66 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -179,6 +179,11 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateWorkload(LayerType type,
= PolymorphicDowncast<const BatchToSpaceNdQueueDescriptor*>(&descriptor);
return std::make_unique<RefBatchToSpaceNdWorkload>(*batchToSpaceNdQueueDescriptor, info);
}
+ case LayerType::BroadcastTo:
+ {
+ auto broadcastToQueueDescriptor = PolymorphicDowncast<const BroadcastToQueueDescriptor*>(&descriptor);
+ return std::make_unique<RefBroadcastToWorkload>(*broadcastToQueueDescriptor, info);
+ }
case LayerType::Cast :
{
auto castQueueDescriptor = PolymorphicDowncast<const CastQueueDescriptor*>(&descriptor);
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 7f047af930..27ca8f607a 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -53,6 +53,7 @@ BACKEND_SOURCES := \
workloads/RefBatchMatMulWorkload.cpp \
workloads/RefBatchNormalizationWorkload.cpp \
workloads/RefBatchToSpaceNdWorkload.cpp \
+ workloads/RefBroadcastToWorkload.cpp \
workloads/RefCastWorkload.cpp \
workloads/RefChannelShuffleWorkload.cpp \
workloads/RefComparisonWorkload.cpp \
diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp
index 09d6ac5d20..e503d3fb7f 100644
--- a/src/backends/reference/test/RefEndToEndTests.cpp
+++ b/src/backends/reference/test/RefEndToEndTests.cpp
@@ -10,6 +10,7 @@
#include <backendsCommon/test/ArgMinMaxEndToEndTestImpl.hpp>
#include <backendsCommon/test/BatchToSpaceNdEndToEndTestImpl.hpp>
#include <backendsCommon/test/BatchMatMulEndToEndTestImpl.hpp>
+#include <backendsCommon/test/BroadcastToEndToEndTestImpl.hpp>
#include <backendsCommon/test/ChannelShuffleEndToEndTestImpl.hpp>
#include <backendsCommon/test/ComparisonEndToEndTestImpl.hpp>
#include <backendsCommon/test/ConcatEndToEndTestImpl.hpp>
@@ -1728,4 +1729,15 @@ TEST_CASE("RefReshapeRemovalNCHWSecondEndToEnd")
{
ReshapeRemovalNCHWEndToEnd<armnn::DataType::Float32>(defaultBackends, true, false);
}
+
+// BroadcastTo
+TEST_CASE("RefBroadcastToEndToEndFloat32")
+{
+ BroadcastToEndToEnd<armnn::DataType::Float32>(defaultBackends);
+}
+
+TEST_CASE("RefBroadcastToEndToEndWithElementWiseBinaryFloat32")
+{
+ BroadcastToEndToEndElementWiseBinary<armnn::DataType::Float32>(defaultBackends);
+}
}
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index a079bb712a..af4ed966b2 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -2823,4 +2823,60 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerInt8NoCifgWithPeeph
ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmInt8WithCifgWithPeepholeNoProjection,
UnidirectionalSequenceLstmInt8WithCifgWithPeepholeNoProjectionTest)
+// Broadcast to
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestFloat32, BroadcastTo1dTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestFloat32, BroadcastTo2dAxis0Test<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestFloat32, BroadcastTo2dAxis1Test<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestFloat32, BroadcastTo3dAxis0Test<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestFloat32, BroadcastTo3dAxis1Test<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestFloat32, BroadcastTo3dAxis2Test<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestFloat32, BroadcastTo4dTest<DataType::Float32>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestFloat16, BroadcastTo1dTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestFloat16, BroadcastTo2dAxis0Test<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestFloat16, BroadcastTo2dAxis1Test<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestFloat16, BroadcastTo3dAxis0Test<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestFloat16, BroadcastTo3dAxis1Test<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestFloat16, BroadcastTo3dAxis2Test<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestFloat16, BroadcastTo4dTest<DataType::Float16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestQAsymmS8, BroadcastTo1dTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestQAsymmS8, BroadcastTo2dAxis0Test<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestQAsymmS8, BroadcastTo2dAxis1Test<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestQAsymmS8, BroadcastTo3dAxis0Test<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestQAsymmS8, BroadcastTo3dAxis1Test<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestQAsymmS8, BroadcastTo3dAxis2Test<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestQAsymmS8, BroadcastTo4dTest<DataType::QAsymmS8>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestQAsymmU8, BroadcastTo1dTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestQAsymmU8, BroadcastTo2dAxis0Test<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestQAsymmU8, BroadcastTo2dAxis1Test<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestQAsymmU8, BroadcastTo3dAxis0Test<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestQAsymmU8, BroadcastTo3dAxis1Test<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestQAsymmU8, BroadcastTo3dAxis2Test<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestQAsymmU8, BroadcastTo4dTest<DataType::QAsymmU8>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestQSymmS8, BroadcastTo1dTest<DataType::QSymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestQSymmS8, BroadcastTo2dAxis0Test<DataType::QSymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestQSymmS8, BroadcastTo2dAxis1Test<DataType::QSymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestQSymmS8, BroadcastTo3dAxis0Test<DataType::QSymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestQSymmS8, BroadcastTo3dAxis1Test<DataType::QSymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestQSymmS8, BroadcastTo3dAxis2Test<DataType::QSymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestQSymmS8, BroadcastTo4dTest<DataType::QSymmS8>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestQSymmS16, BroadcastTo1dTest<DataType::QSymmS16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestQSymmS16, BroadcastTo2dAxis0Test<DataType::QSymmS16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestQSymmS16, BroadcastTo2dAxis1Test<DataType::QSymmS16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestQSymmS16, BroadcastTo3dAxis0Test<DataType::QSymmS16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestQSymmS16, BroadcastTo3dAxis1Test<DataType::QSymmS16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestQSymmS16, BroadcastTo3dAxis2Test<DataType::QSymmS16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestQSymmS16, BroadcastTo4dTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo1dTestSigned32, BroadcastTo1dTest<DataType::Signed32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis0TestSigned32, BroadcastTo2dAxis0Test<DataType::Signed32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo2dAxis1TestSigned32, BroadcastTo2dAxis1Test<DataType::Signed32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestSigned32, BroadcastTo3dAxis0Test<DataType::Signed32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestSigned32, BroadcastTo3dAxis1Test<DataType::Signed32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestSigned32, BroadcastTo3dAxis2Test<DataType::Signed32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestSigned32, BroadcastTo4dTest<DataType::Signed32>)
} \ No newline at end of file
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 9372568133..42f92aec1d 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -79,6 +79,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefBatchNormalizationWorkload.hpp
RefBatchToSpaceNdWorkload.cpp
RefBatchToSpaceNdWorkload.hpp
+ RefBroadcastToWorkload.cpp
+ RefBroadcastToWorkload.hpp
RefCastWorkload.cpp
RefCastWorkload.hpp
RefChannelShuffleWorkload.cpp
diff --git a/src/backends/reference/workloads/RefBroadcastToWorkload.cpp b/src/backends/reference/workloads/RefBroadcastToWorkload.cpp
new file mode 100644
index 0000000000..3a6184d22e
--- /dev/null
+++ b/src/backends/reference/workloads/RefBroadcastToWorkload.cpp
@@ -0,0 +1,48 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefBroadcastToWorkload.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Profiling.hpp"
+#include "Broadcast.hpp"
+
+#include "Decoders.hpp"
+#include "Encoders.hpp"
+
+namespace armnn
+{
+
+RefBroadcastToWorkload::RefBroadcastToWorkload(const BroadcastToQueueDescriptor& descriptor, const WorkloadInfo& info)
+ : RefBaseWorkload(descriptor, info)
+{}
+
+void RefBroadcastToWorkload::Execute() const
+{
+ Execute(m_Data.m_Inputs, m_Data.m_Outputs);
+}
+
+void RefBroadcastToWorkload::ExecuteAsync(ExecutionData& executionData)
+{
+ WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
+ Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
+}
+
+void RefBroadcastToWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
+{
+ ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefBroadcastToWorkload_Execute");
+ const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
+ const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
+
+ std::unique_ptr<Decoder<float>> input = MakeDecoder<float>(inputInfo, inputs[0]->Map());
+ std::unique_ptr<Encoder<float>> output= MakeEncoder<float>(outputInfo, outputs[0]->Map());
+
+ auto broadcastTo = [](float x)
+ {
+ return x;
+ };
+ BroadcastLoop(inputInfo.GetShape(), outputInfo.GetShape()).Unroll(broadcastTo,
+ 0, *input, *output);
+}
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefBroadcastToWorkload.hpp b/src/backends/reference/workloads/RefBroadcastToWorkload.hpp
new file mode 100644
index 0000000000..ac947ae787
--- /dev/null
+++ b/src/backends/reference/workloads/RefBroadcastToWorkload.hpp
@@ -0,0 +1,25 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "RefBaseWorkload.hpp"
+
+namespace armnn
+{
+class RefBroadcastToWorkload : public RefBaseWorkload<BroadcastToQueueDescriptor>
+{
+
+public:
+ explicit RefBroadcastToWorkload(const BroadcastToQueueDescriptor& descriptor,
+ const WorkloadInfo& info);
+
+ void Execute() const override;
+ void ExecuteAsync(ExecutionData& executionData) override;
+
+private:
+ void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const;
+};
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index a36eae501c..98aa27b8a9 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -10,6 +10,7 @@
#include "RefBatchMatMulWorkload.hpp"
#include "RefBatchNormalizationWorkload.hpp"
#include "RefBatchToSpaceNdWorkload.hpp"
+#include "RefBroadcastToWorkload.hpp"
#include "RefCastWorkload.hpp"
#include "RefChannelShuffleWorkload.hpp"
#include "RefComparisonWorkload.hpp"