diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-11-15 15:59:51 +0000 |
---|---|---|
committer | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-11-18 13:59:40 +0000 |
commit | 65a1b1d600cbccf7269409cb7ca0947f0222cb8b (patch) | |
tree | f1c536abaa78b9d57e315987736e3400bf7f976f /1.2/HalPolicy.cpp | |
parent | 444268f77ca864857a52a2f2e1638ddfc7cfe7e6 (diff) | |
download | android-nn-driver-65a1b1d600cbccf7269409cb7ca0947f0222cb8b.tar.gz |
IVGCVSW-4139 Fix regression in ConvertDequantize()
* Removed TENSOR_QUANT8_SYMM from the list of generally supported
tensor data types
* Fixed tensor info in DequantizeIfRequired() for on the fly
dequantized QSymm8 weights
* Moved code for checking whether a Dequantize operator is linked
to FullyConnected or Lstm weights from ConvertDequantize() into
a separate function inside 1.2/HalPolicy.cpp
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I19ea6f89a90f553a964b87d44f8ad8a064e96f7f
Diffstat (limited to '1.2/HalPolicy.cpp')
-rw-r--r-- | 1.2/HalPolicy.cpp | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index f901a31b..c8e29688 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -23,6 +23,63 @@ namespace hal_1_2
 
 using namespace armnn;
 
+namespace
+{
+
+bool IsQSymmDequantizeForWeights(const Operation& operation, const Model& model)
+{
+    const Operand* operand = GetInputOperand<hal_1_2::HalPolicy>(operation, 0, model);
+    if (!operand)
+    {
+        return false;
+    }
+
+    if(!IsQSymm8(*operand))
+    {
+        // Only QSymm8 weights are dequantized on the fly by the driver
+        return false;
+    }
+
+    if (!IsOperandConstant<hal_1_2::HalPolicy>(*operand))
+    {
+        // Non-const input is not accepted for weights
+        return false;
+    }
+
+    // Iterate through all the operations and find the operation feeding from the Dequantize output
+    const size_t outputIndex = operation.outputs[0];
+    for (uint32_t operationIdx = 0; operationIdx < model.operations.size(); ++operationIdx)
+    {
+        const auto& operationIt = model.operations[operationIdx];
+        switch (operationIt.type)
+        {
+            case HalPolicy::OperationType::FULLY_CONNECTED:
+                if (outputIndex == operationIt.inputs[1]) // Weights are bound to slot 1
+                {
+                    // If the output is going into the FC weights return true
+                    return true;
+                }
+                break;
+            case HalPolicy::OperationType::LSTM:
+                for (size_t k = 0; k < operationIt.inputs.size(); ++k)
+                {
+                    if (outputIndex == operationIt.inputs[k])
+                    {
+                        // If the output is going into the LSTM weights return true
+                        return true;
+                    }
+                }
+                break;
+            default:
+                break;
+        }
+    }
+
+    return false;
+}
+
+} // anonymous namespace
+
 bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, ConversionData& data)
 {
     switch (operation.type)
@@ -561,6 +618,14 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model&
 bool HalPolicy::ConvertDequantize(const Operation& operation, const Model& model, ConversionData& data)
 {
     ALOGV("hal_1_2::HalPolicy::ConvertDequantize()");
+
+    if (IsQSymmDequantizeForWeights(operation, model))
+    {
+        // NOTE: QSymm8 weights are dequantized internally by the driver,
+        // therefore this type of Dequantize is implicitly supported
+        return true;
+    }
+
     return ::ConvertDequantize<hal_1_2::HalPolicy>(operation, model, data);
 } |