authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-11-15 15:59:51 +0000
committerAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-11-18 13:59:40 +0000
commit65a1b1d600cbccf7269409cb7ca0947f0222cb8b (patch)
treef1c536abaa78b9d57e315987736e3400bf7f976f
parent444268f77ca864857a52a2f2e1638ddfc7cfe7e6 (diff)
downloadandroid-nn-driver-65a1b1d600cbccf7269409cb7ca0947f0222cb8b.tar.gz
IVGCVSW-4139 Fix regression in ConvertDequantize()
* Removed TENSOR_QUANT8_SYMM from the list of generally supported tensor data types
* Fixed tensor info in DequantizeIfRequired() for on-the-fly dequantized QSymm8 weights
* Moved the code that checks whether a Dequantize operator is linked to FullyConnected or Lstm weights from ConvertDequantize() into a separate function inside 1.2/HalPolicy.cpp

Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I19ea6f89a90f553a964b87d44f8ad8a064e96f7f
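Background for the fixes above: QSymm8 (symmetric 8-bit) quantization carries a scale but no zero point, so dequantizing a weight value reduces to quantized * scale, and the dequantized buffer must then be described by a Float32 tensor info built from the operand's dimensions. A minimal standalone sketch of that arithmetic (hypothetical names, not driver code):

#include <cstddef>
#include <cstdint>
#include <memory>

// Dequantize a symmetric int8 weight buffer to float.
// Symmetric quantization has no zero point: real = quantized * scale.
std::unique_ptr<float[]> DequantizeQSymm8(const int8_t* quantized, size_t length, float scale)
{
    auto dequantized = std::make_unique<float[]>(length);
    for (size_t i = 0; i < length; ++i)
    {
        dequantized[i] = static_cast<float>(quantized[i]) * scale;
    }
    return dequantized;
}

Because the result is float data, DequantizeIfRequired() below now constructs a fresh Float32 TensorInfo from the operand's dimensions instead of mutating the quantized operand's info.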
-rw-r--r--  1.2/HalPolicy.cpp    65
-rw-r--r--  ConversionUtils.hpp  115
-rw-r--r--  Utils.cpp            1
3 files changed, 96 insertions, 85 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index f901a31b..c8e29688 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -23,6 +23,63 @@ namespace hal_1_2
using namespace armnn;
+namespace
+{
+
+bool IsQSymmDequantizeForWeights(const Operation& operation, const Model& model)
+{
+ const Operand* operand = GetInputOperand<hal_1_2::HalPolicy>(operation, 0, model);
+ if (!operand)
+ {
+ return false;
+ }
+
+ if (!IsQSymm8(*operand))
+ {
+ // Only QSymm8 weights are dequantized on the fly by the driver
+ return false;
+ }
+
+ if (!IsOperandConstant<hal_1_2::HalPolicy>(*operand))
+ {
+ // Non-const input is not accepted for weights
+ return false;
+ }
+
+ // Iterate through all the operations and find the operation feeding from the Dequantize output
+ const size_t outputIndex = operation.outputs[0];
+ for (uint32_t operationIdx = 0; operationIdx < model.operations.size(); ++operationIdx)
+ {
+ const auto& operationIt = model.operations[operationIdx];
+ switch (operationIt.type)
+ {
+ case HalPolicy::OperationType::FULLY_CONNECTED:
+ if (outputIndex == operationIt.inputs[1]) // Weights are bound to slot 1
+ {
+ // If the output is going into the FC weights return true
+ return true;
+ }
+ break;
+ case HalPolicy::OperationType::LSTM:
+ for (size_t k = 0; k < operationIt.inputs.size(); ++k)
+ {
+ if (outputIndex == operationIt.inputs[k])
+ {
+ // If the output is going into the LSTM weights return true
+ return true;
+ }
+ }
+ break;
+ default:
+ break;
+ }
+ }
+
+ return false;
+}
+
+} // anonymous namespace
+
bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, ConversionData& data)
{
switch (operation.type)
@@ -561,6 +618,14 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model&
bool HalPolicy::ConvertDequantize(const Operation& operation, const Model& model, ConversionData& data)
{
ALOGV("hal_1_2::HalPolicy::ConvertDequantize()");
+
+ if (IsQSymmDequantizeForWeights(operation, model))
+ {
+ // NOTE: QSymm8 weights are dequantized internally by the driver,
+ // therefore this type of Dequantize is implicitly supported
+ return true;
+ }
+
return ::ConvertDequantize<hal_1_2::HalPolicy>(operation, model, data);
}
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index bcccd272..dbdba4cd 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -187,7 +187,6 @@ inline bool IsOperandTypeSupportedForTensors(V1_2::OperandType type)
type == V1_2::OperandType::TENSOR_FLOAT16 ||
type == V1_2::OperandType::TENSOR_FLOAT32 ||
type == V1_2::OperandType::TENSOR_QUANT8_ASYMM ||
- type == V1_2::OperandType::TENSOR_QUANT8_SYMM ||
type == V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL ||
type == V1_2::OperandType::TENSOR_QUANT16_SYMM ||
type == V1_2::OperandType::TENSOR_INT32;
@@ -715,6 +714,19 @@ bool GetOperandType(const HalOperation& operation,
}
template<typename HalPolicy,
+ typename HalOperand = typename HalPolicy::Operand>
+bool IsOperandConstant(const HalOperand& operand)
+{
+ using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
+
+ HalOperandLifeTime lifetime = operand.lifetime;
+
+ return lifetime == HalOperandLifeTime::CONSTANT_COPY ||
+ lifetime == HalOperandLifeTime::CONSTANT_REFERENCE ||
+ lifetime == HalOperandLifeTime::NO_VALUE;
+}
+
+template<typename HalPolicy,
typename HalOperand = typename HalPolicy::Operand,
typename HalModel = typename HalPolicy::Model>
ConstTensorPin ConvertOperandToConstTensorPin(const HalOperand& operand,
@@ -724,18 +736,13 @@ ConstTensorPin ConvertOperandToConstTensorPin(const HalOperand& operand,
const armnn::TensorShape* overrideTensorShape = nullptr,
bool optional = false)
{
- using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
-
if (!IsOperandTypeSupportedForTensors(operand.type))
{
Fail("%s: unsupported operand type for tensor %s", __func__, toString(operand.type).c_str());
return ConstTensorPin();
}
- if (!optional &&
- operand.lifetime != HalOperandLifeTime::CONSTANT_COPY &&
- operand.lifetime != HalOperandLifeTime::CONSTANT_REFERENCE &&
- operand.lifetime != HalOperandLifeTime::NO_VALUE)
+ if (!optional && !IsOperandConstant<HalPolicy>(operand))
{
Fail("%s: invalid operand lifetime: %s", __func__, toString(operand.lifetime).c_str());
return ConstTensorPin();
@@ -2125,30 +2132,6 @@ bool ConvertDepthwiseConv2d(const HalOperation& operation, const HalModel& model
}
template<typename HalPolicy,
- typename HalOperation = typename HalPolicy::Operation,
- typename HalModel = typename HalPolicy::Model>
-bool IsOperandConstant(const HalOperation& operation,
- uint32_t inputIndex,
- const HalModel& model,
- bool& isConstant)
-{
- using HalOperand = typename HalPolicy::Operand;
- using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
-
- const HalOperand* operand = GetInputOperand<HalPolicy>(operation, inputIndex, model);
- if (!operand)
- {
- return Fail("%s: invalid input operand at index %i", __func__, inputIndex);
- }
-
- isConstant = operand->lifetime == HalOperandLifeTime::CONSTANT_COPY ||
- operand->lifetime == HalOperandLifeTime::CONSTANT_REFERENCE ||
- operand->lifetime == HalOperandLifeTime::NO_VALUE;
-
- return true;
-}
-
-template<typename HalPolicy,
typename Operation = typename HalPolicy::Operation,
typename Model = typename HalPolicy::Model>
bool ConvertDequantize(const Operation& operation, const Model& model, ConversionData& data)
@@ -2167,43 +2150,6 @@ bool ConvertDequantize(const Operation& operation, const Model& model, Conversio
return Fail("%s: Operation has invalid outputs", __func__);
}
- // If the output is going into the FC weights and input is const just return true
- const size_t outputIndex = operation.outputs[0];
- bool input_is_constant = false;
- if (!IsOperandConstant<HalPolicy>(operation,0,model,input_is_constant) && input_is_constant)
- {
- return Fail("Non const input not supported");
- }
-
- // Iterate through the nodes and find the operation feeding from the Dequantize output operand
- for (uint32_t operationIdx = 0; operationIdx < model.operations.size(); ++operationIdx)
- {
- // Search for the FC op which consumes the output of Dequantize with index equal to outputIndex
- const auto& operationIt = model.operations[operationIdx];
- switch (operationIt.type)
- {
- case HalPolicy::OperationType::FULLY_CONNECTED:
- if (outputIndex == operationIt.inputs[1]) // Weights are bound to slot 1
- {
- // If the output is going into the FC weights and input is const just return true
- return true;
- }
- break;
- case HalPolicy::OperationType::LSTM:
- for (size_t k = 0; k < operationIt.inputs.size(); ++k)
- {
- if (outputIndex == operationIt.inputs[k])
- {
- // If the output is going into the LSTM weights and input is const just return true
- return true;
- }
- }
- break;
- default:
- break;
- }
- }
-
const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*outputOperand);
if (IsDynamicTensor(outputInfo))
{
@@ -2357,10 +2303,10 @@ template<typename HalPolicy,
std::tuple<std::unique_ptr<float[]>, size_t, armnn::TensorInfo>
DequantizeIfRequired(size_t operand_index, const Operation& operation, const Model& model, const ConversionData& data)
{
- using Operand = typename HalPolicy::Operand;
+ using HalOperand = typename HalPolicy::Operand;
- bool weights_constant = false;
- if (!(IsOperandConstant<HalPolicy>(operation, operand_index, model, weights_constant) && !weights_constant))
+ const HalOperand* weightsOperand = GetInputOperand<HalPolicy>(operation, operand_index, model);
+ if (!weightsOperand || IsOperandConstant<HalPolicy>(*weightsOperand))
{
return { nullptr, 0, armnn::TensorInfo() };
}
@@ -2371,30 +2317,27 @@ DequantizeIfRequired(size_t operand_index, const Operation& operation, const Mod
// Iterate over the nodes and find the previous operation which should be DEQUANTIZE
for (uint32_t operationIdx = 0; operationIdx < model.operations.size(); ++operationIdx)
{
- const auto& operationIt = model.operations[operationIdx];
- size_t outOpIndex = weightsInputIndex + 1;
-
// Search for the DEQUANTIZE op which has the operand with index equal to operandIndex
+ const auto& operationIt = model.operations[operationIdx];
if (operationIt.type != HalPolicy::OperationType::DEQUANTIZE)
{
continue;
}
- for (size_t i = 0; outOpIndex != weightsInputIndex && i < operation.outputs.size(); ++i)
+ size_t outOpIndex = weightsInputIndex + 1;
+ for (size_t i = 0; outOpIndex != weightsInputIndex && i < operationIt.outputs.size(); ++i)
{
outOpIndex = operationIt.outputs[i];
- break;
}
if (outOpIndex != weightsInputIndex)
{
- break;
+ continue;
}
- const Operand* operand = GetInputOperand<HalPolicy>(operationIt, 0, model);
+ const HalOperand* operand = GetInputOperand<HalPolicy>(operationIt, 0, model);
BOOST_ASSERT(operand);
- armnn::TensorInfo tensorInfo = GetTensorInfoForOperand(*operand);
if (!IsQSymm8(*operand))
{
// Only supporting dequantize from QSYMM8 to FLOAT
@@ -2411,7 +2354,8 @@ DequantizeIfRequired(size_t operand_index, const Operation& operation, const Mod
const uint8_t* quantizedBuffer = reinterpret_cast<const uint8_t*>(startValue);
size_t dequantizedBufferLength = operand->location.length;
- const float quantizationScale = tensorInfo.GetQuantizationScale();
+ const float quantizationScale = operand->scale;
+
auto dequantizedBuffer = std::make_unique<float[]>(dequantizedBufferLength + 1);
for (size_t i = 0; i < dequantizedBufferLength; ++i)
{
@@ -2420,7 +2364,11 @@ DequantizeIfRequired(size_t operand_index, const Operation& operation, const Mod
*dstPtr++ = quantizedBuffer[i] * quantizationScale;
}
- tensorInfo.SetDataType(armnn::DataType::Float32);
+ // Construct tensor info for dequantized ConstTensor
+ armnn::TensorInfo tensorInfo(operand->dimensions.size(),
+ operand->dimensions.data(),
+ armnn::DataType::Float32);
+
return { std::move(dequantizedBuffer), dequantizedBufferLength * sizeof(float), std::move(tensorInfo) };
}
@@ -2476,9 +2424,8 @@ bool ConvertFullyConnected(const Operation& operation, const Model& model, Conve
return Fail("%s: Dynamic output tensors are not supported", __func__);
}
- ConstTensorPin weightsPin = DequantizeAndMakeConstTensorPin<HalPolicy, Operation, Model>(operation, model, data, 1);
-
- ConstTensorPin biasPin = ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data); // 1D
+ ConstTensorPin weightsPin = DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 1);
+ ConstTensorPin biasPin = ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data); // 1D
if (!weightsPin.IsValid())
{
diff --git a/Utils.cpp b/Utils.cpp
index 555039ca..246d6415 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -127,7 +127,6 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand)
type = armnn::DataType::QuantizedSymm8PerAxis;
break;
case V1_2::OperandType::TENSOR_QUANT8_ASYMM:
- case V1_2::OperandType::TENSOR_QUANT8_SYMM:
type = armnn::DataType::QuantisedAsymm8;
break;
case V1_2::OperandType::TENSOR_QUANT16_SYMM: