Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r--  ConversionUtils.hpp  201
1 file changed, 190 insertions, 11 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 1975434a..88c15375 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -187,6 +187,7 @@ inline bool IsOperandTypeSupportedForTensors(V1_2::OperandType type)
type == V1_2::OperandType::TENSOR_FLOAT16 ||
type == V1_2::OperandType::TENSOR_FLOAT32 ||
type == V1_2::OperandType::TENSOR_QUANT8_ASYMM ||
+ type == V1_2::OperandType::TENSOR_QUANT8_SYMM ||
type == V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL ||
type == V1_2::OperandType::TENSOR_QUANT16_SYMM ||
type == V1_2::OperandType::TENSOR_INT32;
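For context: TENSOR_QUANT8_SYMM (QSYMM8) stores signed 8-bit values with a zero point of zero, so a stored value q represents q * scale. A minimal standalone sketch of that mapping (the scale below is an arbitrary example, not taken from this change):

    #include <cstdio>
    #include <initializer_list>

    int main()
    {
        const float scale = 0.5f; // arbitrary example scale
        for (int q : { -128, 0, 127 }) // extremes and midpoint of the QSYMM8 range
        {
            std::printf("q = %4d -> real = %.1f\n", q, q * scale);
        }
        return 0;
    }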
@@ -646,8 +647,8 @@ const HalOperand* GetOutputOperand(const HalOperation& operation,
}
template<typename HalPolicy,
- typename HalOperand = typename HalPolicy::Operand,
- typename HalModel = typename HalPolicy::Model>
+ typename HalOperand = typename HalPolicy::Operand,
+ typename HalModel = typename HalPolicy::Model>
const void* GetOperandValueReadOnlyAddress(const HalOperand& operand,
const HalModel& model,
const ConversionData& data,
@@ -2118,6 +2119,30 @@ bool ConvertDepthwiseConv2d(const HalOperation& operation, const HalModel& model
}
template<typename HalPolicy,
+ typename HalOperation = typename HalPolicy::Operation,
+ typename HalModel = typename HalPolicy::Model>
+bool IsOperandConstant(const HalOperation& operation,
+ uint32_t inputIndex,
+ const HalModel& model,
+ bool& isConstant)
+{
+ using HalOperand = typename HalPolicy::Operand;
+ using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
+
+ const HalOperand* operand = GetInputOperand<HalPolicy>(operation, inputIndex, model);
+ if (!operand)
+ {
+ return Fail("%s: invalid input operand at index %i", __func__, inputIndex);
+ }
+
+    // CONSTANT_COPY and CONSTANT_REFERENCE operands hold constant data; NO_VALUE
+    // (an omitted optional operand) is treated as constant here too.
+    isConstant = operand->lifetime == HalOperandLifeTime::CONSTANT_COPY ||
+                 operand->lifetime == HalOperandLifeTime::CONSTANT_REFERENCE ||
+                 operand->lifetime == HalOperandLifeTime::NO_VALUE;
+
+ return true;
+}
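A hypothetical usage sketch for the helper above (the slot index and messages are made up for illustration; this is not code from the change). Inside a Convert* function, a caller checks constness and bails out of conversion otherwise:

    // Hypothetical caller: weights assumed at input slot 1
    bool weightsAreConstant = false;
    if (!IsOperandConstant<HalPolicy>(operation, 1, model, weightsAreConstant))
    {
        return Fail("%s: could not inspect the weights operand", __func__);
    }
    if (!weightsAreConstant)
    {
        return Fail("%s: non-constant weights are not supported", __func__);
    }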
+
+template<typename HalPolicy,
typename Operation = typename HalPolicy::Operation,
typename Model = typename HalPolicy::Model>
bool ConvertDequantize(const Operation& operation, const Model& model, ConversionData& data)
@@ -2136,6 +2161,43 @@ bool ConvertDequantize(const Operation& operation, const Model& model, Conversio
return Fail("%s: Operation has invalid outputs", __func__);
}
+    // If the output is going into the FC weights and the input is const, just return true:
+    // folding the Dequantize into the consumer's weights requires a constant input.
+    const size_t outputIndex = operation.outputs[0];
+    bool inputIsConstant = false;
+    if (!IsOperandConstant<HalPolicy>(operation, 0, model, inputIsConstant) || !inputIsConstant)
+    {
+        return Fail("%s: non-const input is not supported", __func__);
+    }
+
+    // Iterate through the model's operations and find the one consuming the Dequantize output operand
+ for (uint32_t operationIdx = 0; operationIdx < model.operations.size(); ++operationIdx)
+ {
+ // Search for the FC op which consumes the output of Dequantize with index equal to outputIndex
+ const auto& operationIt = model.operations[operationIdx];
+ switch (operationIt.type)
+ {
+ case HalPolicy::OperationType::FULLY_CONNECTED:
+ if (outputIndex == operationIt.inputs[1]) // Weights are bound to slot 1
+ {
+ // If the output is going into the FC weights and input is const just return true
+ return true;
+ }
+ break;
+ case HalPolicy::OperationType::LSTM:
+ for (size_t k = 0; k < operationIt.inputs.size(); ++k)
+ {
+ if (outputIndex == operationIt.inputs[k])
+ {
+ // If the output is going into the LSTM weights and input is const just return true
+ return true;
+ }
+ }
+ break;
+ default:
+ break;
+ }
+ }
+
const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*outputOperand);
if (IsDynamicTensor(outputInfo))
{
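The search above keys purely on operand indices: an operation's inputs and outputs arrays hold indices into the model's operand table, so the Dequantize output feeds the FULLY_CONNECTED weights exactly when the FC's input slot 1 holds the same index. A self-contained toy sketch of that matching (the types and names are illustrative stand-ins, not the HAL ones):

    #include <cstdint>
    #include <vector>

    struct ToyOperation // stand-in for the HAL operation type
    {
        enum class Type { DEQUANTIZE, FULLY_CONNECTED, OTHER } type;
        std::vector<uint32_t> inputs;  // indices into the model's operand table
        std::vector<uint32_t> outputs;
    };

    // True if any FULLY_CONNECTED consumes 'dequantizeOutput' as its weights
    bool FeedsFullyConnectedWeights(const std::vector<ToyOperation>& operations,
                                    uint32_t dequantizeOutput)
    {
        for (const ToyOperation& op : operations)
        {
            // Weights are bound to input slot 1 of FULLY_CONNECTED
            if (op.type == ToyOperation::Type::FULLY_CONNECTED &&
                op.inputs.size() > 1 &&
                op.inputs[1] == dequantizeOutput)
            {
                return true;
            }
        }
        return false;
    }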
@@ -2269,13 +2331,125 @@ bool ConvertFloor(const Operation& operation, const Model& model, ConversionData
return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data);
}
+inline bool IsQSymm8(const V1_0::Operand&)
+{
+ return false;
+}
+
+#ifdef ARMNN_ANDROID_NN_V1_2
+
+inline bool IsQSymm8(const V1_2::Operand& operand)
+{
+ return operand.type == V1_2::OperandType::TENSOR_QUANT8_SYMM;
+}
+
+#endif
+
template<typename HalPolicy,
typename Operation = typename HalPolicy::Operation,
typename Model = typename HalPolicy::Model>
-bool ConvertFullyConnected(const Operation& operation, const Model& model, ConversionData& data)
+std::tuple<std::unique_ptr<float[]>, size_t, armnn::TensorInfo>
+DequantizeIfRequired(size_t operandIndex, const Operation& operation, const Model& model, const ConversionData& data)
{
using Operand = typename HalPolicy::Operand;
+    // Only proceed if the weights operand is valid and non-constant; constant
+    // weights cannot be the output of a Dequantize op, so there is nothing to fold.
+    bool weightsConstant = false;
+    if (!IsOperandConstant<HalPolicy>(operation, operandIndex, model, weightsConstant) || weightsConstant)
+    {
+        return { nullptr, 0, armnn::TensorInfo() };
+    }
+
+    const size_t weightsInputIndex = operation.inputs[operandIndex];
+
+    // The weights are a non-const tensor, which indicates they might be the output of a DEQUANTIZE op.
+    // Iterate over the operations to find the DEQUANTIZE that produces this operand.
+    for (uint32_t operationIdx = 0; operationIdx < model.operations.size(); ++operationIdx)
+    {
+        // Search for a DEQUANTIZE op whose output operand index equals weightsInputIndex
+        const auto& operationIt = model.operations[operationIdx];
+        if (operationIt.type != HalPolicy::OperationType::DEQUANTIZE)
+        {
+            continue;
+        }
+
+        size_t outOpIndex = weightsInputIndex + 1;
+        for (size_t i = 0; outOpIndex != weightsInputIndex && i < operationIt.outputs.size(); ++i)
+        {
+            outOpIndex = operationIt.outputs[i];
+        }
+
+        // This DEQUANTIZE does not produce the weights operand; keep searching
+        if (outOpIndex != weightsInputIndex)
+        {
+            continue;
+        }
+
+ const Operand* operand = GetInputOperand<HalPolicy>(operationIt, 0, model);
+ BOOST_ASSERT(operand);
+
+ armnn::TensorInfo tensorInfo = GetTensorInfoForOperand(*operand);
+ if (!IsQSymm8(*operand))
+ {
+ // Only supporting dequantize from QSYMM8 to FLOAT
+ break;
+ }
+
+ // Allocate a new buffer for the dequantized data and manually dequantize
+ const void* startValue = GetOperandValueReadOnlyAddress<HalPolicy>(*operand, model, data);
+ if (!startValue)
+ {
+ // Failed to get the operand address
+ break;
+ }
+
+        // QSYMM8 data is signed: one int8 value per byte of the operand buffer
+        const int8_t* quantizedBuffer = reinterpret_cast<const int8_t*>(startValue);
+        size_t dequantizedBufferLength = operand->location.length;
+        const float quantizationScale = tensorInfo.GetQuantizationScale();
+
+        auto dequantizedBuffer = std::make_unique<float[]>(dequantizedBufferLength + 1);
+        float* dstPtr = dequantizedBuffer.get();
+        BOOST_ASSERT(dstPtr);
+        for (size_t i = 0; i < dequantizedBufferLength; ++i)
+        {
+            *dstPtr++ = quantizedBuffer[i] * quantizationScale;
+        }
+
+ tensorInfo.SetDataType(armnn::DataType::Float32);
+ return { std::move(dequantizedBuffer), dequantizedBufferLength * sizeof(float), std::move(tensorInfo) };
+ }
+
+ return { nullptr, 0, armnn::TensorInfo() };
+}
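The manual dequantization in the loop above is just the symmetric mapping real = quantized * scale. A minimal standalone version of that inner step, assuming one signed 8-bit value per byte of the source buffer:

    #include <cstddef>
    #include <cstdint>
    #include <memory>

    // Dequantize a QSYMM8 buffer: the zero point is 0, so real = q * scale
    std::unique_ptr<float[]> DequantizeQSymm8(const int8_t* src, size_t count, float scale)
    {
        auto dst = std::make_unique<float[]>(count);
        for (size_t i = 0; i < count; ++i)
        {
            dst[i] = static_cast<float>(src[i]) * scale;
        }
        return dst;
    }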
+
+template<typename HalPolicy,
+ typename Operation = typename HalPolicy::Operation,
+ typename Model = typename HalPolicy::Model>
+ConstTensorPin DequantizeAndMakeConstTensorPin(const Operation& operation,
+ const Model& model,
+ const ConversionData& data,
+ size_t operandIndex,
+ bool optional = false)
+{
+    auto dequantized = DequantizeIfRequired<HalPolicy, Operation, Model>(operandIndex, operation, model, data);
+ if (std::get<1>(dequantized) == 0 && optional)
+ {
+ // Optional tensor with no values is not really an error. Return it as invalid, but marked as optional
+ return ConstTensorPin(true);
+ }
+
+    return std::get<1>(dequantized) ?
+           ConstTensorPin(std::get<2>(dequantized), std::get<0>(dequantized).get(),
+                          std::get<1>(dequantized), g_DontPermute) :
+           ConvertOperationInputToConstTensorPin<HalPolicy>(operation, operandIndex, model, data);
+}
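For comparison, the same body could unpack the tuple once with C++17 structured bindings instead of repeated std::get calls. A hypothetical alternative spelling, assuming a C++17 toolchain (which this codebase may not target):

    template<typename HalPolicy,
             typename Operation = typename HalPolicy::Operation,
             typename Model = typename HalPolicy::Model>
    ConstTensorPin DequantizeAndMakeConstTensorPin17(const Operation& operation,
                                                     const Model& model,
                                                     const ConversionData& data,
                                                     size_t operandIndex,
                                                     bool optional = false)
    {
        auto [buffer, byteCount, tensorInfo] =
            DequantizeIfRequired<HalPolicy, Operation, Model>(operandIndex, operation, model, data);

        if (byteCount == 0 && optional)
        {
            // Optional tensor with no values: invalid, but marked as optional
            return ConstTensorPin(true);
        }

        return byteCount != 0
            ? ConstTensorPin(tensorInfo, buffer.get(), byteCount, g_DontPermute)
            : ConvertOperationInputToConstTensorPin<HalPolicy>(operation, operandIndex, model, data);
    }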
+
+template<typename HalPolicy,
+ typename Operation = typename HalPolicy::Operation,
+ typename Model = typename HalPolicy::Model>
+bool ConvertFullyConnected(const Operation& operation, const Model& model, ConversionData& data)
+{
+ using Operand = typename HalPolicy::Operand;
LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
if (!input.IsValid())
{
@@ -2296,15 +2470,18 @@ bool ConvertFullyConnected(const Operation& operation, const Model& model, Conve
return Fail("%s: Dynamic output tensors are not supported", __func__);
}
- // ArmNN does not currently support non-fixed weights or bias
- ConstTensorPin weightsPin =
- ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 1, model, data); // 2D
- ConstTensorPin biasPin =
- ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data); // 1D
+    ConstTensorPin weightsPin =
+        DequantizeAndMakeConstTensorPin<HalPolicy, Operation, Model>(operation, model, data, 1);
- if (!weightsPin.IsValid() || !biasPin.IsValid())
+ ConstTensorPin biasPin = ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data); // 1D
+
+ if (!weightsPin.IsValid())
{
- return Fail("%s: Operation has invalid inputs", __func__);
+ return Fail("%s: Operation has invalid weights", __func__);
+ }
+
+ if (!biasPin.IsValid())
+ {
+ return Fail("%s: Operation has invalid bias", __func__);
}
armnn::ConstTensor weights = weightsPin.GetConstTensor();
@@ -2314,7 +2491,9 @@ bool ConvertFullyConnected(const Operation& operation, const Model& model, Conve
try
{
reshapedInfo.SetShape(FlattenFullyConnectedInput(inputInfo.GetShape(), weights.GetInfo().GetShape()));
- } catch (const std::exception &e) {
+ }
+ catch (const std::exception& e)
+ {
return Fail("%s: %s", __func__, e.what());
}
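For reference, FlattenFullyConnectedInput (defined elsewhere, not shown in this diff) reshapes an N-D input into the 2D [batchSize, inputSize] shape that FULLY_CONNECTED expects. A simplified stand-in under the usual NNAPI rule, where inputSize comes from the weights' second dimension and batchSize is whatever remains (an assumption for illustration, not the actual implementation):

    #include <stdexcept>
    #include <armnn/Tensor.hpp>

    // Collapse an N-D shape to [totalElements / inputSize, inputSize]
    armnn::TensorShape FlattenToFullyConnectedInput(const armnn::TensorShape& inputShape,
                                                    unsigned int inputSize)
    {
        unsigned int totalElements = 1;
        for (unsigned int i = 0; i < inputShape.GetNumDimensions(); ++i)
        {
            totalElements *= inputShape[i];
        }

        if (inputSize == 0 || totalElements % inputSize != 0)
        {
            throw std::runtime_error("input shape cannot be flattened to [batch, inputSize]");
        }

        return armnn::TensorShape({ totalElements / inputSize, inputSize });
    }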