aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp8
1 files changed, 7 insertions, 1 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 654aeb55dc..3e04a19df4 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -14,8 +14,8 @@
#include <LayerSupportCommon.hpp>
#include <backendsCommon/LayerSupportRules.hpp>
-#include <vector>
#include <array>
+#include <vector>
namespace armnn
{
@@ -940,6 +940,9 @@ bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
"Reference comparison: output is not of type Boolean");
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference comparison: shapes are not suitable for implicit broadcast.");
+
return supported;
}
@@ -1751,6 +1754,9 @@ bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
"Reference LogicalBinary: input and output types do not match");
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference LogicalBinary: shapes are not suitable for implicit broadcast.");
+
return supported;
}