aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-10-29 14:25:09 +0000
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-10-30 10:39:18 +0000
commit3e0982b5f54cf2b90f319de417debf386cadcb30 (patch)
treec7419aa1b4d748c3b7340a9c80c7cfad14e047e7
parent7d2ccfd5a804c265dc18636d4dd9ea2df40c3403 (diff)
downloadandroid-nn-driver-3e0982b5f54cf2b90f319de417debf386cadcb30.tar.gz
IVGCVSW-3805 Add Comparison support to the android-nn-driver
* Added support for the following HAL1.2 operations: EQUAL, GREATER, GREATER_EQUAL, LESS, LESS_EQUAL and NOT_EQUAL Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: I71b68db70232da4aaad28caa7b0b5f9a1d7778d0
-rw-r--r--1.2/HalPolicy.cpp72
-rw-r--r--1.2/HalPolicy.hpp5
2 files changed, 76 insertions, 1 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index 3e836d49..86b3bb78 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -14,6 +14,7 @@
#include <armnn/TypesUtils.hpp>
#include <cmath>
+#include <string>
namespace armnn_driver
{
@@ -46,12 +47,18 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model,
return ConvertDequantize(operation, model, data);
case V1_2::OperationType::DIV:
return ConvertDiv(operation, model, data);
+ case V1_2::OperationType::EQUAL:
+ return ConvertComparison(operation, model, data, ComparisonOperation::Equal);
case V1_2::OperationType::EXPAND_DIMS:
return ConvertExpandDims(operation, model, data);
case V1_2::OperationType::FLOOR:
return ConvertFloor(operation, model, data);
case V1_2::OperationType::FULLY_CONNECTED:
return ConvertFullyConnected(operation, model, data);
+ case V1_2::OperationType::GREATER:
+ return ConvertComparison(operation, model, data, ComparisonOperation::Greater);
+ case V1_2::OperationType::GREATER_EQUAL:
+ return ConvertComparison(operation, model, data, ComparisonOperation::GreaterOrEqual);
case V1_2::OperationType::GROUPED_CONV_2D:
return ConvertGroupedConv2d(operation, model, data);
case V1_2::OperationType::INSTANCE_NORMALIZATION:
@@ -60,6 +67,10 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model,
return ConvertL2Normalization(operation, model, data);
case V1_2::OperationType::L2_POOL_2D:
return ConvertL2Pool2d(operation, model, data);
+ case V1_2::OperationType::LESS:
+ return ConvertComparison(operation, model, data, ComparisonOperation::Less);
+ case V1_2::OperationType::LESS_EQUAL:
+ return ConvertComparison(operation, model, data, ComparisonOperation::LessOrEqual);
case V1_2::OperationType::LOCAL_RESPONSE_NORMALIZATION:
return ConvertLocalResponseNormalization(operation, model, data);
case V1_2::OperationType::LOGISTIC:
@@ -78,6 +89,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model,
return ConvertMinimum(operation, model, data);
case V1_2::OperationType::MUL:
return ConvertMul(operation, model, data);
+ case V1_2::OperationType::NOT_EQUAL:
+ return ConvertComparison(operation, model, data, ComparisonOperation::NotEqual);
case V1_2::OperationType::PAD:
return ConvertPad(operation, model, data);
case V1_2::OperationType::PAD_V2:
@@ -152,6 +165,63 @@ bool HalPolicy::ConvertBatchToSpaceNd(const Operation& operation, const Model& m
return ::ConvertBatchToSpaceNd<hal_1_2::HalPolicy>(operation, model, data);
}
+bool HalPolicy::ConvertComparison(const Operation& operation,
+ const Model& model,
+ ConversionData& data,
+ ComparisonOperation comparisonOperation)
+{
+ ALOGV("hal_1_2::HalPolicy::ConvertComparison()");
+ ALOGV("comparisonOperation = %s", GetComparisonOperationAsCString(comparisonOperation));
+
+ LayerInputHandle input0 = ConvertToLayerInputHandle<hal_1_2::HalPolicy>(operation, 0, model, data);
+ LayerInputHandle input1 = ConvertToLayerInputHandle<hal_1_2::HalPolicy>(operation, 1, model, data);
+
+ if (!(input0.IsValid() && input1.IsValid()))
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ const Operand* output = GetOutputOperand<hal_1_2::HalPolicy>(operation, 0, model);
+ if (!output)
+ {
+ return Fail("%s: Could not read output 0", __func__);
+ }
+
+ const TensorInfo& inputInfo0 = input0.GetTensorInfo();
+ const TensorInfo& inputInfo1 = input1.GetTensorInfo();
+ const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
+
+ if (IsDynamicTensor(outputInfo))
+ {
+ return Fail("%s: Dynamic output tensors are not supported", __func__);
+ }
+
+ ComparisonDescriptor descriptor(comparisonOperation);
+
+ bool isSupported = false;
+ FORWARD_LAYER_SUPPORT_FUNC(__func__,
+ IsComparisonSupported,
+ data.m_Backends,
+ isSupported,
+ inputInfo0,
+ inputInfo1,
+ outputInfo,
+ descriptor);
+
+ if (!isSupported)
+ {
+ return false;
+ }
+
+ IConnectableLayer* layer = data.m_Network->AddComparisonLayer(descriptor);
+ assert(layer != nullptr);
+
+ input0.Connect(layer->GetInputSlot(0));
+ input1.Connect(layer->GetInputSlot(1));
+
+ return SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 0, *layer, model, data);
+}
+
bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& model, ConversionData& data)
{
ALOGV("hal_1_2::HalPolicy::ConvertConcatenation()");
@@ -1073,7 +1143,7 @@ bool HalPolicy::ConvertLogSoftmax(const Operation& operation, const Model& model
return false;
}
- armnn::IConnectableLayer* layer = data.m_Network->AddLogSoftmaxLayer(descriptor);
+ IConnectableLayer* layer = data.m_Network->AddLogSoftmaxLayer(descriptor);
if (!layer)
{
return Fail("%s: AddLogSoftmaxLayer() returned nullptr", __func__);
diff --git a/1.2/HalPolicy.hpp b/1.2/HalPolicy.hpp
index 743ac11e..d611102b 100644
--- a/1.2/HalPolicy.hpp
+++ b/1.2/HalPolicy.hpp
@@ -39,6 +39,11 @@ private:
static bool ConvertBatchToSpaceNd(const Operation& operation, const Model& model, ConversionData& data);
+ static bool ConvertComparison(const Operation& operation,
+ const Model& model,
+ ConversionData& data,
+ armnn::ComparisonOperation comparisonOperation);
+
static bool ConvertConcatenation(const Operation& operation, const Model& model, ConversionData& data);
static bool ConvertConv2d(const Operation& operation, const Model& model, ConversionData& data);