aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefLayerSupport.cpp
diff options
context:
space:
mode:
authorIdriss Chaouch <idriss.chaouch@arm.com>2023-08-28 14:28:31 +0100
committerIdriss Chaouch <idriss.chaouch@arm.com>2023-08-31 11:26:28 +0100
commit98e383eadf4e670d057ad725c7fe7924fea8e36b (patch)
tree35acac15aa69ab405887289cb9674d388f06f96b /src/backends/reference/RefLayerSupport.cpp
parent2be039bce38a4fa436e8310dfe14ebfff20d57bd (diff)
downloadarmnn-98e383eadf4e670d057ad725c7fe7924fea8e36b.tar.gz
IVGCVSW-7525 Add broadcast_to operator
Signed-off-by: Idriss Chaouch <idriss.chaouch@arm.com> Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I94ec5f9120b2d736fdf98d00ec5137a4efd739b8
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp53
1 files changed, 44 insertions, 9 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 0b1b9c7824..defdf0d807 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -100,6 +100,11 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type,
infos[1],
*(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
reasonIfUnsupported);
+ case LayerType::BroadcastTo:
+ return IsBroadcastToSupported(infos[0],
+ infos[1],
+ *(PolymorphicDowncast<const BroadcastToDescriptor*>(&descriptor)),
+ reasonIfUnsupported);
case LayerType::Comparison:
return IsComparisonSupported(infos[0],
infos[1],
@@ -807,20 +812,50 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
return supported;
}
+bool RefLayerSupport::IsBroadcastToSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const BroadcastToDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ IgnoreUnused(descriptor);
+
+ bool supported = true;
+
+ std::array<DataType, 8> supportedTypes
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmS8,
+ DataType::QAsymmU8,
+ DataType::QSymmS8,
+ DataType::QSymmS16,
+ DataType::Signed32,
+ DataType::Signed64
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "BroadcastTo: input type not supported.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "BroadcastTo: output type not supported");
+
+ return supported;
+}
+
bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
std::array<DataType, 9> supportedInputTypes =
- {
- DataType::Float32,
- DataType::Float16,
- DataType::QSymmS8,
- DataType::QAsymmS8,
- DataType::QAsymmU8,
- DataType::QSymmS16,
- DataType::Signed32
- };
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QSymmS8,
+ DataType::QAsymmS8,
+ DataType::QAsymmU8,
+ DataType::QSymmS16,
+ DataType::Signed32
+ };
bool supported = true;
supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,