From 3e0982b5f54cf2b90f319de417debf386cadcb30 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Tue, 29 Oct 2019 14:25:09 +0000 Subject: IVGCVSW-3805 Add Comparison support to the android-nn-driver * Added support for the following HAL1.2 operations: EQUAL, GREATER, GREATER_EQUAL, LESS, LESS_EQUAL and NOT_EQUAL Signed-off-by: Aron Virginas-Tar Change-Id: I71b68db70232da4aaad28caa7b0b5f9a1d7778d0 --- 1.2/HalPolicy.cpp | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1.2/HalPolicy.hpp | 5 ++++ 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp index 3e836d49..86b3bb78 100644 --- a/1.2/HalPolicy.cpp +++ b/1.2/HalPolicy.cpp @@ -14,6 +14,7 @@ #include #include +#include namespace armnn_driver { @@ -46,12 +47,18 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, return ConvertDequantize(operation, model, data); case V1_2::OperationType::DIV: return ConvertDiv(operation, model, data); + case V1_2::OperationType::EQUAL: + return ConvertComparison(operation, model, data, ComparisonOperation::Equal); case V1_2::OperationType::EXPAND_DIMS: return ConvertExpandDims(operation, model, data); case V1_2::OperationType::FLOOR: return ConvertFloor(operation, model, data); case V1_2::OperationType::FULLY_CONNECTED: return ConvertFullyConnected(operation, model, data); + case V1_2::OperationType::GREATER: + return ConvertComparison(operation, model, data, ComparisonOperation::Greater); + case V1_2::OperationType::GREATER_EQUAL: + return ConvertComparison(operation, model, data, ComparisonOperation::GreaterOrEqual); case V1_2::OperationType::GROUPED_CONV_2D: return ConvertGroupedConv2d(operation, model, data); case V1_2::OperationType::INSTANCE_NORMALIZATION: @@ -60,6 +67,10 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, return ConvertL2Normalization(operation, model, data); case V1_2::OperationType::L2_POOL_2D: return ConvertL2Pool2d(operation, model, data); + case V1_2::OperationType::LESS: + return ConvertComparison(operation, model, data, ComparisonOperation::Less); + case V1_2::OperationType::LESS_EQUAL: + return ConvertComparison(operation, model, data, ComparisonOperation::LessOrEqual); case V1_2::OperationType::LOCAL_RESPONSE_NORMALIZATION: return ConvertLocalResponseNormalization(operation, model, data); case V1_2::OperationType::LOGISTIC: @@ -78,6 +89,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, return ConvertMinimum(operation, model, data); case V1_2::OperationType::MUL: return ConvertMul(operation, model, data); + case V1_2::OperationType::NOT_EQUAL: + return ConvertComparison(operation, model, data, ComparisonOperation::NotEqual); case V1_2::OperationType::PAD: return ConvertPad(operation, model, data); case V1_2::OperationType::PAD_V2: @@ -152,6 +165,63 @@ bool HalPolicy::ConvertBatchToSpaceNd(const Operation& operation, const Model& m return ::ConvertBatchToSpaceNd(operation, model, data); } +bool HalPolicy::ConvertComparison(const Operation& operation, + const Model& model, + ConversionData& data, + ComparisonOperation comparisonOperation) +{ + ALOGV("hal_1_2::HalPolicy::ConvertComparison()"); + ALOGV("comparisonOperation = %s", GetComparisonOperationAsCString(comparisonOperation)); + + LayerInputHandle input0 = ConvertToLayerInputHandle(operation, 0, model, data); + LayerInputHandle input1 = ConvertToLayerInputHandle(operation, 1, model, data); + + if (!(input0.IsValid() && input1.IsValid())) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + const Operand* output = GetOutputOperand(operation, 0, model); + if (!output) + { + return Fail("%s: Could not read output 0", __func__); + } + + const TensorInfo& inputInfo0 = input0.GetTensorInfo(); + const TensorInfo& inputInfo1 = input1.GetTensorInfo(); + const TensorInfo& outputInfo = GetTensorInfoForOperand(*output); + + if (IsDynamicTensor(outputInfo)) + { + return Fail("%s: Dynamic output tensors are not supported", __func__); + } + + ComparisonDescriptor descriptor(comparisonOperation); + + bool isSupported = false; + FORWARD_LAYER_SUPPORT_FUNC(__func__, + IsComparisonSupported, + data.m_Backends, + isSupported, + inputInfo0, + inputInfo1, + outputInfo, + descriptor); + + if (!isSupported) + { + return false; + } + + IConnectableLayer* layer = data.m_Network->AddComparisonLayer(descriptor); + assert(layer != nullptr); + + input0.Connect(layer->GetInputSlot(0)); + input1.Connect(layer->GetInputSlot(1)); + + return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data); +} + bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& model, ConversionData& data) { ALOGV("hal_1_2::HalPolicy::ConvertConcatenation()"); @@ -1073,7 +1143,7 @@ bool HalPolicy::ConvertLogSoftmax(const Operation& operation, const Model& model return false; } - armnn::IConnectableLayer* layer = data.m_Network->AddLogSoftmaxLayer(descriptor); + IConnectableLayer* layer = data.m_Network->AddLogSoftmaxLayer(descriptor); if (!layer) { return Fail("%s: AddLogSoftmaxLayer() returned nullptr", __func__); diff --git a/1.2/HalPolicy.hpp b/1.2/HalPolicy.hpp index 743ac11e..d611102b 100644 --- a/1.2/HalPolicy.hpp +++ b/1.2/HalPolicy.hpp @@ -39,6 +39,11 @@ private: static bool ConvertBatchToSpaceNd(const Operation& operation, const Model& model, ConversionData& data); + static bool ConvertComparison(const Operation& operation, + const Model& model, + ConversionData& data, + armnn::ComparisonOperation comparisonOperation); + static bool ConvertConcatenation(const Operation& operation, const Model& model, ConversionData& data); static bool ConvertConv2d(const Operation& operation, const Model& model, ConversionData& data); -- cgit v1.2.1