about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@Arm.com>2020-06-11 17:35:44 +0100
committerFinn Williams <Finn.Williams@Arm.com>2020-07-06 19:07:52 +0100
commitfc884b4141a28fbd3c62f665341ec88158fcd332 (patch)
treedf18555baced3f4d917a78d24d87c5e2d1778f83
parent00b586b1f3106a4e6970ca7feacb1cc1892d107e (diff)
downloadandroid-nn-driver-fc884b4141a28fbd3c62f665341ec88158fcd332.tar.gz
Add support for Rank and scalar tensors
!armnn:3330 Signed-off-by: Finn Williams <Finn.Williams@Arm.com> Change-Id: Icc429d9fabb570193d12bffef0e00dda7b51032f
-rw-r--r--1.3/HalPolicy.cpp8
-rw-r--r--1.3/HalPolicy.hpp2
-rw-r--r--ConversionUtils.hpp6
-rw-r--r--ConversionUtils_1_3.hpp47
-rw-r--r--Utils.cpp16
-rw-r--r--Utils.hpp22
6 files changed, 88 insertions, 13 deletions
diff --git a/1.3/HalPolicy.cpp b/1.3/HalPolicy.cpp
index 1c4a1e36..79df1c7f 100644
--- a/1.3/HalPolicy.cpp
+++ b/1.3/HalPolicy.cpp
@@ -111,6 +111,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model,
return ConvertQuantizedLstm(operation, model, data);
case V1_3::OperationType::QUANTIZED_16BIT_LSTM:
return ConvertQuantized16BitLstm(operation, model, data);
+ case V1_3::OperationType::RANK:
+ return ConvertRank(operation, model, data);
case V1_3::OperationType::RELU:
return ConvertReLu(operation, model, data);
case V1_3::OperationType::RELU1:
@@ -394,6 +396,12 @@ bool HalPolicy::ConvertQuantized16BitLstm(const Operation& operation, const Mode
return ::ConvertQuantized16BitLstm<hal_1_3::HalPolicy>(operation, model, data);
}
+bool HalPolicy::ConvertRank(const Operation& operation, const Model& model, ConversionData& data)
+{
+ ALOGV("hal_1_3::HalPolicy::ConvertRank()");
+ return ::ConvertRank<hal_1_3::HalPolicy>(operation, model, data);
+}
+
bool HalPolicy::ConvertReLu(const Operation& operation, const Model& model, ConversionData& data)
{
ALOGV("hal_1_3::HalPolicy::ConvertReLu()");
diff --git a/1.3/HalPolicy.hpp b/1.3/HalPolicy.hpp
index 6df2ce2d..0eb5f4d7 100644
--- a/1.3/HalPolicy.hpp
+++ b/1.3/HalPolicy.hpp
@@ -123,6 +123,8 @@ private:
static bool ConvertQuantized16BitLstm(const Operation& operation, const Model& model, ConversionData& data);
+ static bool ConvertRank(const Operation& operation, const Model& model, ConversionData& data);
+
static bool ConvertReLu(const Operation& operation, const Model& model, ConversionData& data);
static bool ConvertReLu1(const Operation& operation, const Model& model, ConversionData& data);
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 0fbd4e4d..5dc9993d 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -1451,14 +1451,16 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
uint32_t outputIndex,
armnn::IConnectableLayer& layer,
const HalModel& model,
- ConversionData& data)
+ ConversionData& data,
+ const armnn::TensorInfo* overrideOutputInfo = nullptr)
{
return SetupAndTrackLayerOutputSlot<HalPolicy>(operation,
outputIndex,
layer,
outputIndex,
model,
- data);
+ data,
+ overrideOutputInfo);
}
template<typename HalPolicy,
diff --git a/ConversionUtils_1_3.hpp b/ConversionUtils_1_3.hpp
index 3acb49a7..d5d89df1 100644
--- a/ConversionUtils_1_3.hpp
+++ b/ConversionUtils_1_3.hpp
@@ -642,4 +642,49 @@ bool ConvertQuantizedLstm(const HalOperation& operation, const HalModel& model,
SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo));
}
-} // armnn_driver namespace \ No newline at end of file
+template<typename HalPolicy,
+ typename HalOperation = typename HalPolicy::Operation,
+ typename HalModel = typename HalPolicy::Model>
+bool ConvertRank(const HalOperation& operation, const HalModel& model, ConversionData& data)
+{
+ using HalOperand = typename HalPolicy::Operand;
+
+ const HalOperand* inputOperand = GetInputOperand<HalPolicy>(operation, 0, model);
+ const HalOperand* outputOperand = GetOutputOperand<HalPolicy>(operation, 0, model);
+
+ if (inputOperand == nullptr || outputOperand == nullptr)
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ const Shape inputOperandShape = GetOperandShape(*inputOperand);
+ const Shape outputOperandShape = GetOperandShape(*outputOperand);
+
+ LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
+ if (!input.IsValid())
+ {
+ return Fail("%s: Could not read input 0", __func__);
+ }
+
+ armnn::TensorInfo outInfo = GetTensorInfoForOperand(*outputOperand);
+
+ bool isSupported = false;
+ FORWARD_LAYER_SUPPORT_FUNC(__func__,
+ IsRankSupported,
+ data.m_Backends,
+ isSupported,
+ input.GetTensorInfo(),
+ outInfo);
+ if (!isSupported)
+ {
+ return false;
+ }
+
+ armnn::IConnectableLayer* layer = data.m_Network->AddRankLayer();
+ assert(layer != nullptr);
+ input.Connect(layer->GetInputSlot(0));
+
+ return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, &outInfo);
+}
+
+} // armnn_driver namespace
diff --git a/Utils.cpp b/Utils.cpp
index 6481c287..d94a9377 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -200,6 +200,9 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand)
case V1_3::OperandType::TENSOR_INT32:
type = armnn::DataType::Signed32;
break;
+ case V1_3::OperandType::INT32:
+ type = armnn::DataType::Signed32;
+ break;
case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
type = armnn::DataType::QAsymmS8;
break;
@@ -207,7 +210,17 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand)
throw UnsupportedOperand<V1_3::OperandType>(operand.type);
}
- TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type);
+ TensorInfo ret;
+ // 0 dimensional tensors will be flagged as scalars
+ if ( operand.dimensions.size() != 0)
+ {
+ ret = TensorInfo(operand.dimensions.size(), operand.dimensions.data(), type);
+ }
+ else
+ {
+ ret = TensorInfo(TensorShape(armnn::Dimensionality::Scalar), type);
+ }
+
if (perChannel)
{
// ExtraParams is expected to be of type channelQuant
@@ -224,7 +237,6 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand)
ret.SetQuantizationScale(operand.scale);
ret.SetQuantizationOffset(operand.zeroPoint);
}
-
return ret;
}
diff --git a/Utils.hpp b/Utils.hpp
index b61ddb21..d58d2735 100644
--- a/Utils.hpp
+++ b/Utils.hpp
@@ -150,18 +150,24 @@ inline V1_2::OutputShape ComputeShape(const armnn::TensorInfo& info)
{
V1_2::OutputShape shape;
- android::hardware::hidl_vec<uint32_t> dimensions;
-
armnn::TensorShape tensorShape = info.GetShape();
- const unsigned int numDims = tensorShape.GetNumDimensions();
- dimensions.resize(numDims);
-
- for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx)
+ // Android will expect scalars as a zero dimensional tensor
+ if(tensorShape.GetDimensionality() == armnn::Dimensionality::Scalar)
+ {
+ shape.dimensions = android::hardware::hidl_vec<uint32_t>{};
+ }
+ else
{
- dimensions[outputIdx] = tensorShape[outputIdx];
+ android::hardware::hidl_vec<uint32_t> dimensions;
+ const unsigned int numDims = tensorShape.GetNumDimensions();
+ dimensions.resize(numDims);
+ for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx)
+ {
+ dimensions[outputIdx] = tensorShape[outputIdx];
+ }
+ shape.dimensions = dimensions;
}
- shape.dimensions = dimensions;
shape.isSufficient = true;
return shape;