From d08119479be236043205aab5b23d8a29bc3e8768 Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Mon, 18 Nov 2019 17:11:21 +0000 Subject: IVGCVSW-4137 Failing LSTM android-nn-driver Unit Tests in HAL 1.2 * Fixed for failing LSTM unit tests Signed-off-by: Sadik Armagan Change-Id: I773c5227bc8d5606924cc0472c51172476773056 --- ConversionUtils.hpp | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index dbdba4cd..0637c2b5 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -2300,15 +2300,22 @@ inline bool IsQSymm8(const V1_2::Operand& operand) template -std::tuple, size_t, armnn::TensorInfo> +std::tuple, size_t, armnn::TensorInfo, int> DequantizeIfRequired(size_t operand_index, const Operation& operation, const Model& model, const ConversionData& data) { using HalOperand = typename HalPolicy::Operand; const HalOperand* weightsOperand = GetInputOperand(operation, operand_index, model); - if (!weightsOperand || IsOperandConstant(*weightsOperand)) + if (!weightsOperand) { - return { nullptr, 0, armnn::TensorInfo() }; + // Invalid Operand will return with error code '-1' + return { nullptr, 0, armnn::TensorInfo(), -1 }; + } + + if (IsOperandConstant(*weightsOperand)) + { + // Weights are already constant + return { nullptr, 0, armnn::TensorInfo(), 0 }; } const size_t weightsInputIndex = operation.inputs[operand_index]; @@ -2369,10 +2376,10 @@ DequantizeIfRequired(size_t operand_index, const Operation& operation, const Mod operand->dimensions.data(), armnn::DataType::Float32); - return { std::move(dequantizedBuffer), dequantizedBufferLength * sizeof(float), std::move(tensorInfo) }; + return { std::move(dequantizedBuffer), dequantizedBufferLength * sizeof(float), std::move(tensorInfo), 0 }; } - return { nullptr, 0, armnn::TensorInfo() }; + return { nullptr, 0, armnn::TensorInfo() , 0}; } template(operandIndex,operation, model, data); - if (std::get<1>(dequantized) == 0 && optional) + if (std::get<3>(dequantized) == -1) + { + // Return it as invalid, tensor with no values is not really an error + return ConstTensorPin(); + } + + if (std::get<1>(dequantized) == 0) { - // Optional tensor with no values is not really an error. Return it as invalid, but marked as optional - return ConstTensorPin(true); + return ConvertOperationInputToConstTensorPin( + operation, operandIndex, model, data, g_DontPermute, nullptr, optional); + } - return std::get<1>(dequantized) ? - ConstTensorPin(std::get<2>(dequantized), std::get<0>(dequantized).get(), - std::get<1>(dequantized), g_DontPermute): - ConvertOperationInputToConstTensorPin(operation, operandIndex, model, data); + return ConstTensorPin(std::get<2>(dequantized), std::get<0>(dequantized).get(), + std::get<1>(dequantized), g_DontPermute); } -- cgit v1.2.1