From 85f9654dd951d247e4f6673603bf9cf00c299712 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Thu, 12 Sep 2019 16:26:29 +0100 Subject: IVGCVSW-3663 Add EXPAND_DIMS to the android-nn-driver Signed-off-by: Narumol Prangnawarat Change-Id: Ibf6c53822f728c0e15a9ca1cd2c2ad3593edbd82 --- 1.2/HalPolicy.cpp | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1.2/HalPolicy.hpp | 2 ++ NnapiSupport.txt | 1 + 3 files changed, 74 insertions(+) diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp index 1de57e5a..7aa6967a 100644 --- a/1.2/HalPolicy.cpp +++ b/1.2/HalPolicy.cpp @@ -9,6 +9,7 @@ #include #include +#include #include @@ -39,6 +40,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, return ConvertDequantize(operation, model, data); case V1_2::OperationType::DIV: return ConvertDiv(operation, model, data); + case V1_2::OperationType::EXPAND_DIMS: + return ConvertExpandDims(operation, model, data); case V1_2::OperationType::FLOOR: return ConvertFloor(operation, model, data); case V1_2::OperationType::FULLY_CONNECTED: @@ -473,6 +476,74 @@ bool HalPolicy::ConvertDiv(const Operation& operation, const Model& model, Conve return ::ConvertDiv(operation, model, data); } +bool HalPolicy::ConvertExpandDims(const Operation& operation, const Model& model, ConversionData& data) +{ + ALOGV("hal_1_2::HalPolicy::ConvertExpandDims()"); + + LayerInputHandle input = ConvertToLayerInputHandle(operation, 0, model, data); + + if (!input.IsValid()) + { + return Fail("%s: Operation has invalid input", __func__); + } + + const Operand* output = GetOutputOperand(operation, 0, model); + if (!output) + { + return Fail("%s: Operation has invalid output", __func__); + } + + const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output); + if (IsDynamicTensor(outputInfo)) + { + return Fail("%s: Dynamic output tensors are not supported", __func__); + } + + int32_t axis; + if (!GetInputScalar(operation, 1, OperandType::INT32, axis, model, data)) + { + return Fail("%s: failed to get axis input value", __func__); + } + + armnn::TensorShape targetShape; + + try + { + targetShape = armnnUtils::ExpandDims(input.GetTensorInfo().GetShape(), axis); + } + catch (const std::exception &e) + { + return Fail("%s: %s", __func__, e.what()); + } + + if (targetShape != outputInfo.GetShape()) + { + return Fail("%s: Shape of the output operand does not match the resolved expanded shape", __func__); + } + + armnn::ReshapeDescriptor reshapeDescriptor; + reshapeDescriptor.m_TargetShape = targetShape; + + bool isSupported = false; + FORWARD_LAYER_SUPPORT_FUNC(__func__, + IsReshapeSupported, + data.m_Backends, + isSupported, + input.GetTensorInfo(), + reshapeDescriptor); + + if (!isSupported) + { + return false; + } + + armnn::IConnectableLayer* layer = data.m_Network->AddReshapeLayer(reshapeDescriptor); + assert(layer != nullptr); + input.Connect(layer->GetInputSlot(0)); + + return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data); +} + bool HalPolicy::ConvertFloor(const Operation& operation, const Model& model, ConversionData& data) { ALOGV("hal_1_2::HalPolicy::ConvertFloor()"); diff --git a/1.2/HalPolicy.hpp b/1.2/HalPolicy.hpp index e3d39702..c7e1d4bc 100644 --- a/1.2/HalPolicy.hpp +++ b/1.2/HalPolicy.hpp @@ -49,6 +49,8 @@ private: static bool ConvertDiv(const Operation& operation, const Model& model, ConversionData& data); + static bool ConvertExpandDims(const Operation& operation, const Model& model, ConversionData& data); + static bool ConvertFloor(const Operation& operation, const Model& model, ConversionData& data); static bool ConvertFullyConnected(const Operation& operation, const Model& model, ConversionData& data); diff --git a/NnapiSupport.txt b/NnapiSupport.txt index 651de0f7..3dea13e3 100644 --- a/NnapiSupport.txt +++ b/NnapiSupport.txt @@ -22,6 +22,7 @@ CONV_2D (FLOAT32,QUANT8_ASYMM) DEPTHWISE_CONV_2D (FLOAT32,QUANT8_ASYMM) DIV (FLOAT32,QUANT8_ASYMM) DEQUANTIZE (FLOAT32,QUANT8_ASYMM) +EXPAND_DIMS (FLOAT32,QUANT8_ASYMM) FLOOR (FLOAT32) FULLY_CONNECTED (FLOAT32,QUANT8_ASYMM) L2_NORMALIZATION (FLOAT32) -- cgit v1.2.1