diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2019-11-18 17:11:21 +0000 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2019-11-18 17:11:58 +0000 |
commit | d08119479be236043205aab5b23d8a29bc3e8768 (patch) | |
tree | 50db354b9c7ac9199ee69598be630a5b97ccfaa3 | |
parent | 9acf579752d7dbf43a26e933224854d2d003da30 (diff) | |
download | android-nn-driver-d08119479be236043205aab5b23d8a29bc3e8768.tar.gz |
IVGCVSW-4137 Failing LSTM android-nn-driver Unit Tests in HAL 1.2
* Fixed for failing LSTM unit tests
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I773c5227bc8d5606924cc0472c51172476773056
-rw-r--r-- | ConversionUtils.hpp | 36 |
1 files changed, 24 insertions, 12 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index dbdba4cd..0637c2b5 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -2300,15 +2300,22 @@ inline bool IsQSymm8(const V1_2::Operand& operand) template<typename HalPolicy, typename Operation = typename HalPolicy::Operation, typename Model = typename HalPolicy::Model> -std::tuple<std::unique_ptr<float[]>, size_t, armnn::TensorInfo> +std::tuple<std::unique_ptr<float[]>, size_t, armnn::TensorInfo, int> DequantizeIfRequired(size_t operand_index, const Operation& operation, const Model& model, const ConversionData& data) { using HalOperand = typename HalPolicy::Operand; const HalOperand* weightsOperand = GetInputOperand<HalPolicy>(operation, operand_index, model); - if (!weightsOperand || IsOperandConstant<HalPolicy>(*weightsOperand)) + if (!weightsOperand) { - return { nullptr, 0, armnn::TensorInfo() }; + // Invalid Operand will return with error code '-1' + return { nullptr, 0, armnn::TensorInfo(), -1 }; + } + + if (IsOperandConstant<HalPolicy>(*weightsOperand)) + { + // Weights are already constant + return { nullptr, 0, armnn::TensorInfo(), 0 }; } const size_t weightsInputIndex = operation.inputs[operand_index]; @@ -2369,10 +2376,10 @@ DequantizeIfRequired(size_t operand_index, const Operation& operation, const Mod operand->dimensions.data(), armnn::DataType::Float32); - return { std::move(dequantizedBuffer), dequantizedBufferLength * sizeof(float), std::move(tensorInfo) }; + return { std::move(dequantizedBuffer), dequantizedBufferLength * sizeof(float), std::move(tensorInfo), 0 }; } - return { nullptr, 0, armnn::TensorInfo() }; + return { nullptr, 0, armnn::TensorInfo() , 0}; } template<typename HalPolicy, @@ -2385,16 +2392,21 @@ ConstTensorPin DequantizeAndMakeConstTensorPin(const Operation& operation, bool optional = false) { auto dequantized = DequantizeIfRequired<HalPolicy, Operation, Model>(operandIndex,operation, model, data); - if (std::get<1>(dequantized) == 0 && optional) + if (std::get<3>(dequantized) == -1) + { + // Return it as invalid, tensor with no values is not really an error + return ConstTensorPin(); + } + + if (std::get<1>(dequantized) == 0) { - // Optional tensor with no values is not really an error. Return it as invalid, but marked as optional - return ConstTensorPin(true); + return ConvertOperationInputToConstTensorPin<HalPolicy>( + operation, operandIndex, model, data, g_DontPermute, nullptr, optional); + } - return std::get<1>(dequantized) ? - ConstTensorPin(std::get<2>(dequantized), std::get<0>(dequantized).get(), - std::get<1>(dequantized), g_DontPermute): - ConvertOperationInputToConstTensorPin<HalPolicy>(operation, operandIndex, model, data); + return ConstTensorPin(std::get<2>(dequantized), std::get<0>(dequantized).get(), + std::get<1>(dequantized), g_DontPermute); } |