aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2020-03-19 13:54:04 +0000
committerSadik Armagan <sadik.armagan@arm.com>2020-03-19 13:54:04 +0000
commit793a70c1a5b8021137705d312916ac14ca35509a (patch)
tree44a497b8373fea525add6ce7fa2c20c23092cc49
parent4de83c5a6a57d0468d9f2f854c94bc4a760b66b6 (diff)
downloadandroid-nn-driver-793a70c1a5b8021137705d312916ac14ca35509a.tar.gz
IVGCVSW-4565 TENSOR_BOOL8 data type not supported in AndroidNN Driver
* Added TENSOR_BOOL8 support * Added Broadcast support to comparision operators !armnn:2903 Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: I844e32b57399eff2dc60af9b2099145316c80cae
-rw-r--r--1.2/HalPolicy.cpp8
-rw-r--r--ConversionUtils.hpp1
-rw-r--r--Utils.cpp3
3 files changed, 9 insertions, 3 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index b3ccc47f..1811688f 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -279,9 +279,11 @@ bool HalPolicy::ConvertComparison(const Operation& operation,
IConnectableLayer* layer = data.m_Network->AddComparisonLayer(descriptor);
assert(layer != nullptr);
-
- input0.Connect(layer->GetInputSlot(0));
- input1.Connect(layer->GetInputSlot(1));
+ bool isReshapeSupported = BroadcastTensor(input0, input1, layer, data);
+ if (!isReshapeSupported)
+ {
+ return false;
+ }
return SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 0, *layer, model, data);
}
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index d4ca4345..90b1c7de 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -189,6 +189,7 @@ inline bool IsOperandTypeSupportedForTensors(V1_0::OperandType type)
inline bool IsOperandTypeSupportedForTensors(V1_2::OperandType type)
{
return type == V1_2::OperandType::BOOL ||
+ type == V1_2::OperandType::TENSOR_BOOL8 ||
type == V1_2::OperandType::TENSOR_FLOAT16 ||
type == V1_2::OperandType::TENSOR_FLOAT32 ||
type == V1_2::OperandType::TENSOR_QUANT8_ASYMM ||
diff --git a/Utils.cpp b/Utils.cpp
index c95f6e12..c548f849 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -113,6 +113,9 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand)
DataType type;
switch (operand.type)
{
+ case V1_2::OperandType::TENSOR_BOOL8:
+ type = armnn::DataType::Boolean;
+ break;
case V1_2::OperandType::TENSOR_FLOAT32:
type = armnn::DataType::Float32;
break;