author     Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>   2019-11-15 15:59:51 +0000
committer  Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>   2019-11-18 13:59:40 +0000
commit     65a1b1d600cbccf7269409cb7ca0947f0222cb8b (patch)
tree       f1c536abaa78b9d57e315987736e3400bf7f976f /1.2
parent     444268f77ca864857a52a2f2e1638ddfc7cfe7e6 (diff)
download   android-nn-driver-65a1b1d600cbccf7269409cb7ca0947f0222cb8b.tar.gz
IVGCVSW-4139 Fix regression in ConvertDequantize()
* Removed TENSOR_QUANT8_SYMM from the list of generally supported tensor data types
* Fixed tensor info in DequantizeIfRequired() for on the fly dequantized QSymm8 weights
* Moved code for checking whether a Dequantize operator is linked to FullyConnected or Lstm weights from ConvertDequantize() into a separate function inside 1.2/HalPolicy.cpp

Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I19ea6f89a90f553a964b87d44f8ad8a064e96f7f
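Note: the new IsQSymmDequantizeForWeights() function in the diff below calls IsQSymm8() and IsOperandConstant<>(), which are defined elsewhere in the driver and are not part of this patch. A minimal sketch of the checks being assumed, using simplified stand-in types rather than the NN HAL definitions:

// Sketch only (simplified stand-in types, not the NN HAL definitions): the two
// checks assumed by IsQSymmDequantizeForWeights() before it looks for a consumer.
#include <cstdint>

enum class OperandType : int32_t { TENSOR_FLOAT32, TENSOR_QUANT8_SYMM /* ... */ };
enum class OperandLifeTime : int32_t { TEMPORARY_VARIABLE, CONSTANT_COPY, CONSTANT_REFERENCE /* ... */ };

struct Operand
{
    OperandType     type;
    OperandLifeTime lifetime;
};

// Only QSymm8 weights are dequantized on the fly by the driver
bool IsQSymm8(const Operand& operand)
{
    return operand.type == OperandType::TENSOR_QUANT8_SYMM;
}

// Weights have to be constant data baked into the model
bool IsOperandConstant(const Operand& operand)
{
    return operand.lifetime == OperandLifeTime::CONSTANT_COPY ||
           operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE;
}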
Diffstat (limited to '1.2')
-rw-r--r--  1.2/HalPolicy.cpp  65
1 file changed, 65 insertions, 0 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index f901a31b..c8e29688 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -23,6 +23,63 @@ namespace hal_1_2
using namespace armnn;
+namespace
+{
+
+bool IsQSymmDequantizeForWeights(const Operation& operation, const Model& model)
+{
+    const Operand* operand = GetInputOperand<hal_1_2::HalPolicy>(operation, 0, model);
+    if (!operand)
+    {
+        return false;
+    }
+
+    if (!IsQSymm8(*operand))
+    {
+        // Only QSymm8 weights are dequantized on the fly by the driver
+        return false;
+    }
+
+    if (!IsOperandConstant<hal_1_2::HalPolicy>(*operand))
+    {
+        // Non-const input is not accepted for weights
+        return false;
+    }
+
+    // Iterate through all the operations and find the operation feeding from the Dequantize output
+    const size_t outputIndex = operation.outputs[0];
+    for (uint32_t operationIdx = 0; operationIdx < model.operations.size(); ++operationIdx)
+    {
+        const auto& operationIt = model.operations[operationIdx];
+        switch (operationIt.type)
+        {
+            case HalPolicy::OperationType::FULLY_CONNECTED:
+                if (outputIndex == operationIt.inputs[1]) // Weights are bound to slot 1
+                {
+                    // If the output is going into the FC weights return true
+                    return true;
+                }
+                break;
+            case HalPolicy::OperationType::LSTM:
+                for (size_t k = 0; k < operationIt.inputs.size(); ++k)
+                {
+                    if (outputIndex == operationIt.inputs[k])
+                    {
+                        // If the output is going into the LSTM weights return true
+                        return true;
+                    }
+                }
+                break;
+            default:
+                break;
+        }
+    }
+
+    return false;
+}
+
+} // anonymous namespace
+
bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, ConversionData& data)
{
    switch (operation.type)
@@ -561,6 +618,14 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model&
bool HalPolicy::ConvertDequantize(const Operation& operation, const Model& model, ConversionData& data)
{
    ALOGV("hal_1_2::HalPolicy::ConvertDequantize()");
+
+    if (IsQSymmDequantizeForWeights(operation, model))
+    {
+        // NOTE: QSymm8 weights are dequantized internally by the driver,
+        //       therefore this type of Dequantize is implicitly supported
+        return true;
+    }
+
    return ::ConvertDequantize<hal_1_2::HalPolicy>(operation, model, data);
}
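For readers unfamiliar with the NN HAL model layout, the standalone sketch below is hypothetical (simplified types and a made-up FeedsWeights() helper, not driver code). It illustrates the consumer lookup performed by IsQSymmDequantizeForWeights(): take the DEQUANTIZE output operand index and scan every operation to see whether that operand is consumed as FULLY_CONNECTED weights (input slot 1) or as any LSTM input.

// Standalone illustration (hypothetical, simplified types; not the NN HAL structs)
// of the consumer lookup: does the Dequantize output feed FC or LSTM weights?
#include <cstdint>
#include <iostream>
#include <vector>

enum class OpType { DEQUANTIZE, FULLY_CONNECTED, LSTM, OTHER };

struct Op
{
    OpType type;
    std::vector<uint32_t> inputs;
    std::vector<uint32_t> outputs;
};

bool FeedsWeights(const Op& dequantize, const std::vector<Op>& ops)
{
    const uint32_t outputIndex = dequantize.outputs[0];
    for (const Op& op : ops)
    {
        switch (op.type)
        {
            case OpType::FULLY_CONNECTED:
                // Weights are bound to input slot 1 of FULLY_CONNECTED
                if (op.inputs.size() > 1 && op.inputs[1] == outputIndex)
                {
                    return true;
                }
                break;
            case OpType::LSTM:
                // Any LSTM input fed by the Dequantize counts as weights here
                for (uint32_t in : op.inputs)
                {
                    if (in == outputIndex)
                    {
                        return true;
                    }
                }
                break;
            default:
                break;
        }
    }
    return false;
}

int main()
{
    // Operand 2 is the dequantized weights tensor consumed by FULLY_CONNECTED
    Op dequantize{OpType::DEQUANTIZE, {0}, {2}};
    Op fullyConnected{OpType::FULLY_CONNECTED, {1, 2, 3}, {4}};

    std::cout << std::boolalpha
              << FeedsWeights(dequantize, {dequantize, fullyConnected}) << std::endl; // prints: true
    return 0;
}

The early return added to ConvertDequantize() keeps the driver from rejecting models whose constant QSymm8 weights only pass through a Dequantize on their way into FullyConnected or LSTM, since those weights are dequantized on the fly elsewhere in the conversion code.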