aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2019-11-18 17:11:21 +0000
committerSadik Armagan <sadik.armagan@arm.com>2019-11-18 17:11:58 +0000
commitd08119479be236043205aab5b23d8a29bc3e8768 (patch)
tree50db354b9c7ac9199ee69598be630a5b97ccfaa3
parent9acf579752d7dbf43a26e933224854d2d003da30 (diff)
downloadandroid-nn-driver-d08119479be236043205aab5b23d8a29bc3e8768.tar.gz
IVGCVSW-4137 Failing LSTM android-nn-driver Unit Tests in HAL 1.2
* Fixed for failing LSTM unit tests Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: I773c5227bc8d5606924cc0472c51172476773056
-rw-r--r--ConversionUtils.hpp36
1 files changed, 24 insertions, 12 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index dbdba4cd..0637c2b5 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -2300,15 +2300,22 @@ inline bool IsQSymm8(const V1_2::Operand& operand)
template<typename HalPolicy,
typename Operation = typename HalPolicy::Operation,
typename Model = typename HalPolicy::Model>
-std::tuple<std::unique_ptr<float[]>, size_t, armnn::TensorInfo>
+std::tuple<std::unique_ptr<float[]>, size_t, armnn::TensorInfo, int>
DequantizeIfRequired(size_t operand_index, const Operation& operation, const Model& model, const ConversionData& data)
{
using HalOperand = typename HalPolicy::Operand;
const HalOperand* weightsOperand = GetInputOperand<HalPolicy>(operation, operand_index, model);
- if (!weightsOperand || IsOperandConstant<HalPolicy>(*weightsOperand))
+ if (!weightsOperand)
{
- return { nullptr, 0, armnn::TensorInfo() };
+ // Invalid Operand will return with error code '-1'
+ return { nullptr, 0, armnn::TensorInfo(), -1 };
+ }
+
+ if (IsOperandConstant<HalPolicy>(*weightsOperand))
+ {
+ // Weights are already constant
+ return { nullptr, 0, armnn::TensorInfo(), 0 };
}
const size_t weightsInputIndex = operation.inputs[operand_index];
@@ -2369,10 +2376,10 @@ DequantizeIfRequired(size_t operand_index, const Operation& operation, const Mod
operand->dimensions.data(),
armnn::DataType::Float32);
- return { std::move(dequantizedBuffer), dequantizedBufferLength * sizeof(float), std::move(tensorInfo) };
+ return { std::move(dequantizedBuffer), dequantizedBufferLength * sizeof(float), std::move(tensorInfo), 0 };
}
- return { nullptr, 0, armnn::TensorInfo() };
+ return { nullptr, 0, armnn::TensorInfo() , 0};
}
template<typename HalPolicy,
@@ -2385,16 +2392,21 @@ ConstTensorPin DequantizeAndMakeConstTensorPin(const Operation& operation,
bool optional = false)
{
auto dequantized = DequantizeIfRequired<HalPolicy, Operation, Model>(operandIndex,operation, model, data);
- if (std::get<1>(dequantized) == 0 && optional)
+ if (std::get<3>(dequantized) == -1)
+ {
+ // Return it as invalid, tensor with no values is not really an error
+ return ConstTensorPin();
+ }
+
+ if (std::get<1>(dequantized) == 0)
{
- // Optional tensor with no values is not really an error. Return it as invalid, but marked as optional
- return ConstTensorPin(true);
+ return ConvertOperationInputToConstTensorPin<HalPolicy>(
+ operation, operandIndex, model, data, g_DontPermute, nullptr, optional);
+
}
- return std::get<1>(dequantized) ?
- ConstTensorPin(std::get<2>(dequantized), std::get<0>(dequantized).get(),
- std::get<1>(dequantized), g_DontPermute):
- ConvertOperationInputToConstTensorPin<HalPolicy>(operation, operandIndex, model, data);
+ return ConstTensorPin(std::get<2>(dequantized), std::get<0>(dequantized).get(),
+ std::get<1>(dequantized), g_DontPermute);
}