diff options
author | Idriss Chaouch <idriss.chaouch@arm.com> | 2023-08-28 14:28:31 +0100 |
---|---|---|
committer | Idriss Chaouch <idriss.chaouch@arm.com> | 2023-08-31 11:26:28 +0100 |
commit | 98e383eadf4e670d057ad725c7fe7924fea8e36b (patch) | |
tree | 35acac15aa69ab405887289cb9674d388f06f96b /src/backends/reference/RefLayerSupport.cpp | |
parent | 2be039bce38a4fa436e8310dfe14ebfff20d57bd (diff) | |
download | armnn-98e383eadf4e670d057ad725c7fe7924fea8e36b.tar.gz |
IVGCVSW-7525 Add broadcast_to operator
Signed-off-by: Idriss Chaouch <idriss.chaouch@arm.com>
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I94ec5f9120b2d736fdf98d00ec5137a4efd739b8
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 53 |
1 files changed, 44 insertions, 9 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 0b1b9c7824..defdf0d807 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -100,6 +100,11 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type, infos[1], *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)), reasonIfUnsupported); + case LayerType::BroadcastTo: + return IsBroadcastToSupported(infos[0], + infos[1], + *(PolymorphicDowncast<const BroadcastToDescriptor*>(&descriptor)), + reasonIfUnsupported); case LayerType::Comparison: return IsComparisonSupported(infos[0], infos[1], @@ -807,20 +812,50 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input, return supported; } +bool RefLayerSupport::IsBroadcastToSupported(const TensorInfo& input, + const TensorInfo& output, + const BroadcastToDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported) const +{ + IgnoreUnused(descriptor); + + bool supported = true; + + std::array<DataType, 8> supportedTypes + { + DataType::Float32, + DataType::Float16, + DataType::QAsymmS8, + DataType::QAsymmU8, + DataType::QSymmS8, + DataType::QSymmS16, + DataType::Signed32, + DataType::Signed64 + }; + + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "BroadcastTo: input type not supported."); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "BroadcastTo: output type not supported"); + + return supported; +} + bool RefLayerSupport::IsCastSupported(const TensorInfo& input, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const { std::array<DataType, 9> supportedInputTypes = - { - DataType::Float32, - DataType::Float16, - DataType::QSymmS8, - DataType::QAsymmS8, - DataType::QAsymmU8, - DataType::QSymmS16, - DataType::Signed32 - }; + { + DataType::Float32, + DataType::Float16, + DataType::QSymmS8, + DataType::QAsymmS8, + DataType::QAsymmU8, + DataType::QSymmS16, + DataType::Signed32 + }; bool supported = true; supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported, |