author     Narumol Prangnawarat <narumol.prangnawarat@arm.com>  2019-09-12 16:26:29 +0100
committer  Narumol Prangnawarat <narumol.prangnawarat@arm.com>  2019-09-16 11:04:54 +0100
commit     85f9654dd951d247e4f6673603bf9cf00c299712 (patch)
tree       7bb23f8ae3537bc20e78d634e12093d978cddf1f
parent     a234ab1c19ef92bfb76e9c55f002b81e6e1cd206 (diff)
IVGCVSW-3663 Add EXPAND_DIMS to the android-nn-driver
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ibf6c53822f728c0e15a9ca1cd2c2ad3593edbd82
-rw-r--r--  1.2/HalPolicy.cpp  71
-rw-r--r--  1.2/HalPolicy.hpp   2
-rw-r--r--  NnapiSupport.txt    1
3 files changed, 74 insertions, 0 deletions
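
The new operation is lowered onto ArmNN's existing Reshape support: the driver resolves the expanded output shape from the input shape and the INT32 axis operand (delegating to armnnUtils::ExpandDims), checks the result against the declared output operand, and then adds a Reshape layer with that target shape. As a rough illustration of the shape rule involved, the sketch below inserts a size-1 dimension at the requested axis; the helper name expandDimsShape and the plain std::vector representation are made up for the example and are not the armnnUtils implementation.

#include <cstdint>
#include <stdexcept>
#include <vector>

// Illustrative sketch only: compute the shape produced by EXPAND_DIMS.
// A dimension of size 1 is inserted at 'axis'; negative axes count from
// the end, as in the NNAPI specification. The real driver relies on
// armnnUtils::ExpandDims, which may differ in detail.
std::vector<unsigned int> expandDimsShape(const std::vector<unsigned int>& inputShape, int32_t axis)
{
    const int32_t outputRank = static_cast<int32_t>(inputShape.size()) + 1;

    // Valid axes are [-outputRank, outputRank - 1].
    if (axis < -outputRank || axis >= outputRank)
    {
        throw std::invalid_argument("EXPAND_DIMS: axis is out of range");
    }
    if (axis < 0)
    {
        axis += outputRank;
    }

    std::vector<unsigned int> outputShape(inputShape);
    outputShape.insert(outputShape.begin() + axis, 1u);
    return outputShape;
}

// Example: expandDimsShape({2, 3}, 1)  -> {2, 1, 3}
//          expandDimsShape({2, 3}, -1) -> {2, 3, 1}
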
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index 1de57e5a..7aa6967a 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -9,6 +9,7 @@
 #include <DataLayoutIndexed.hpp>
 #include <Half.hpp>
+#include <TensorUtils.hpp>
 #include <cmath>
@@ -39,6 +40,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model,
             return ConvertDequantize(operation, model, data);
         case V1_2::OperationType::DIV:
             return ConvertDiv(operation, model, data);
+        case V1_2::OperationType::EXPAND_DIMS:
+            return ConvertExpandDims(operation, model, data);
         case V1_2::OperationType::FLOOR:
             return ConvertFloor(operation, model, data);
         case V1_2::OperationType::FULLY_CONNECTED:
@@ -473,6 +476,74 @@ bool HalPolicy::ConvertDiv(const Operation& operation, const Model& model, Conve
     return ::ConvertDiv<hal_1_2::HalPolicy>(operation, model, data);
 }
 
+bool HalPolicy::ConvertExpandDims(const Operation& operation, const Model& model, ConversionData& data)
+{
+    ALOGV("hal_1_2::HalPolicy::ConvertExpandDims()");
+
+    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
+
+    if (!input.IsValid())
+    {
+        return Fail("%s: Operation has invalid input", __func__);
+    }
+
+    const Operand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
+    if (!output)
+    {
+        return Fail("%s: Operation has invalid output", __func__);
+    }
+
+    const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
+    if (IsDynamicTensor(outputInfo))
+    {
+        return Fail("%s: Dynamic output tensors are not supported", __func__);
+    }
+
+    int32_t axis;
+    if (!GetInputScalar<HalPolicy>(operation, 1, OperandType::INT32, axis, model, data))
+    {
+        return Fail("%s: failed to get axis input value", __func__);
+    }
+
+    armnn::TensorShape targetShape;
+
+    try
+    {
+        targetShape = armnnUtils::ExpandDims(input.GetTensorInfo().GetShape(), axis);
+    }
+    catch (const std::exception& e)
+    {
+        return Fail("%s: %s", __func__, e.what());
+    }
+
+    if (targetShape != outputInfo.GetShape())
+    {
+        return Fail("%s: Shape of the output operand does not match the resolved expanded shape", __func__);
+    }
+
+    armnn::ReshapeDescriptor reshapeDescriptor;
+    reshapeDescriptor.m_TargetShape = targetShape;
+
+    bool isSupported = false;
+    FORWARD_LAYER_SUPPORT_FUNC(__func__,
+                               IsReshapeSupported,
+                               data.m_Backends,
+                               isSupported,
+                               input.GetTensorInfo(),
+                               reshapeDescriptor);
+
+    if (!isSupported)
+    {
+        return false;
+    }
+
+    armnn::IConnectableLayer* layer = data.m_Network->AddReshapeLayer(reshapeDescriptor);
+    assert(layer != nullptr);
+    input.Connect(layer->GetInputSlot(0));
+
+    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data);
+}
+
 bool HalPolicy::ConvertFloor(const Operation& operation, const Model& model, ConversionData& data)
 {
     ALOGV("hal_1_2::HalPolicy::ConvertFloor()");
diff --git a/1.2/HalPolicy.hpp b/1.2/HalPolicy.hpp
index e3d39702..c7e1d4bc 100644
--- a/1.2/HalPolicy.hpp
+++ b/1.2/HalPolicy.hpp
@@ -49,6 +49,8 @@ private:
     static bool ConvertDiv(const Operation& operation, const Model& model, ConversionData& data);
 
+    static bool ConvertExpandDims(const Operation& operation, const Model& model, ConversionData& data);
+
     static bool ConvertFloor(const Operation& operation, const Model& model, ConversionData& data);
 
     static bool ConvertFullyConnected(const Operation& operation, const Model& model, ConversionData& data);
diff --git a/NnapiSupport.txt b/NnapiSupport.txt
index 651de0f7..3dea13e3 100644
--- a/NnapiSupport.txt
+++ b/NnapiSupport.txt
@@ -22,6 +22,7 @@ CONV_2D (FLOAT32,QUANT8_ASYMM)
 DEPTHWISE_CONV_2D (FLOAT32,QUANT8_ASYMM)
 DIV (FLOAT32,QUANT8_ASYMM)
 DEQUANTIZE (FLOAT32,QUANT8_ASYMM)
+EXPAND_DIMS (FLOAT32,QUANT8_ASYMM)
 FLOOR (FLOAT32)
 FULLY_CONNECTED (FLOAT32,QUANT8_ASYMM)
 L2_NORMALIZATION (FLOAT32)
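
On the application side, a model exercises this path through the NNAPI C API once the driver reports EXPAND_DIMS as supported. The following is a hedged client-side sketch, not part of the driver change: it builds a single-operation model with an illustrative [2, 3] float input expanded on axis 1, and omits error handling for brevity.

#include <android/NeuralNetworks.h>

// Hypothetical client-side sketch (not part of this commit): builds a model
// containing a single EXPAND_DIMS operation. The caller is assumed to have
// created 'model' with ANeuralNetworksModel_create; error checking is omitted.
void BuildExpandDimsModel(ANeuralNetworksModel* model)
{
    const uint32_t inputDims[]  = {2, 3};     // illustrative input shape
    const uint32_t outputDims[] = {2, 1, 3};  // input shape with axis 1 expanded

    ANeuralNetworksOperandType inputType  = {ANEURALNETWORKS_TENSOR_FLOAT32, 2, inputDims,  0.0f, 0};
    ANeuralNetworksOperandType axisType   = {ANEURALNETWORKS_INT32,          0, nullptr,    0.0f, 0};
    ANeuralNetworksOperandType outputType = {ANEURALNETWORKS_TENSOR_FLOAT32, 3, outputDims, 0.0f, 0};

    ANeuralNetworksModel_addOperand(model, &inputType);   // operand 0: input tensor
    ANeuralNetworksModel_addOperand(model, &axisType);    // operand 1: axis scalar
    ANeuralNetworksModel_addOperand(model, &outputType);  // operand 2: output tensor

    const int32_t axis = 1;
    ANeuralNetworksModel_setOperandValue(model, 1, &axis, sizeof(axis));

    const uint32_t opInputs[]  = {0, 1};
    const uint32_t opOutputs[] = {2};
    ANeuralNetworksModel_addOperation(model, ANEURALNETWORKS_EXPAND_DIMS, 2, opInputs, 1, opOutputs);

    const uint32_t modelInputs[]  = {0};
    const uint32_t modelOutputs[] = {2};
    ANeuralNetworksModel_identifyInputsAndOutputs(model, 1, modelInputs, 1, modelOutputs);
    ANeuralNetworksModel_finish(model);
}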