From 65a1b1d600cbccf7269409cb7ca0947f0222cb8b Mon Sep 17 00:00:00 2001
From: Aron Virginas-Tar
Date: Fri, 15 Nov 2019 15:59:51 +0000
Subject: IVGCVSW-4139 Fix regression in ConvertDequantize()

* Removed TENSOR_QUANT8_SYMM from the list of generally supported tensor
  data types
* Fixed tensor info in DequantizeIfRequired() for on the fly dequantized
  QSymm8 weights
* Moved code for checking whether a Dequantize operator is linked to
  FullyConnected or Lstm weights from ConvertDequantize() into a separate
  function inside 1.2/HalPolicy.cpp

Signed-off-by: Aron Virginas-Tar
Change-Id: I19ea6f89a90f553a964b87d44f8ad8a064e96f7f
---
 1.2/HalPolicy.cpp | 65 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 65 insertions(+)

diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index f901a31b..c8e29688 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -23,6 +23,63 @@ namespace hal_1_2
 
 using namespace armnn;
 
+namespace
+{
+
+bool IsQSymmDequantizeForWeights(const Operation& operation, const Model& model)
+{
+    const Operand* operand = GetInputOperand(operation, 0, model);
+    if (!operand)
+    {
+        return false;
+    }
+
+    if(!IsQSymm8(*operand))
+    {
+        // Only QSymm8 weights are dequantized on the fly by the driver
+        return false;
+    }
+
+    if (!IsOperandConstant(*operand))
+    {
+        // Non-const input is not accepted for weights
+        return false;
+    }
+
+    // Iterate through all the operations and find the operation feeding from the Dequantize output
+    const size_t outputIndex = operation.outputs[0];
+    for (uint32_t operationIdx = 0; operationIdx < model.operations.size(); ++operationIdx)
+    {
+        const auto& operationIt = model.operations[operationIdx];
+        switch (operationIt.type)
+        {
+            case HalPolicy::OperationType::FULLY_CONNECTED:
+                if (outputIndex == operationIt.inputs[1]) // Weights are bound to slot 1
+                {
+                    // If the output is going into the FC weights return true
+                    return true;
+                }
+                break;
+            case HalPolicy::OperationType::LSTM:
+                for (size_t k = 0; k < operationIt.inputs.size(); ++k)
+                {
+                    if (outputIndex == operationIt.inputs[k])
+                    {
+                        // If the output is going into the LSTM weights return true
+                        return true;
+                    }
+                }
+                break;
+            default:
+                break;
+        }
+    }
+
+    return false;
+}
+
+} // anonymous namespace
+
 bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, ConversionData& data)
 {
     switch (operation.type)
@@ -561,6 +618,14 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model&
 bool HalPolicy::ConvertDequantize(const Operation& operation, const Model& model, ConversionData& data)
 {
     ALOGV("hal_1_2::HalPolicy::ConvertDequantize()");
+
+    if (IsQSymmDequantizeForWeights(operation, model))
+    {
+        // NOTE: QSymm8 weights are dequantized internally by the driver,
+        // therefore this type of Dequantize is implicitly supported
+        return true;
+    }
+
     return ::ConvertDequantize(operation, model, data);
 }
 
--
cgit v1.2.1
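
For context on the "dequantized on the fly" behaviour mentioned in the commit message and comments above, the sketch below illustrates plain QSymm8 (TENSOR_QUANT8_SYMM) weight dequantization: per-tensor symmetric quantization has a zero point of 0, so real = scale * quantized. This is a hypothetical, standalone example; the function name and signature are illustrative and are not part of this patch or the driver.

    // Hypothetical example, not part of the patch above.
    #include <cstdint>
    #include <vector>

    std::vector<float> DequantizeQSymm8Weights(const std::vector<int8_t>& quantized, float scale)
    {
        std::vector<float> dequantized;
        dequantized.reserve(quantized.size());
        for (int8_t value : quantized)
        {
            // real = scale * quantized; no zero-point offset for symmetric quantization
            dequantized.push_back(scale * static_cast<float>(value));
        }
        return dequantized;
    }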