diff options
Diffstat (limited to 'shim/sl/canonical/ConversionUtils.cpp')
-rw-r--r-- | shim/sl/canonical/ConversionUtils.cpp | 39 |
1 files changed, 32 insertions, 7 deletions
diff --git a/shim/sl/canonical/ConversionUtils.cpp b/shim/sl/canonical/ConversionUtils.cpp index 020410d30e..96a8ddca6a 100644 --- a/shim/sl/canonical/ConversionUtils.cpp +++ b/shim/sl/canonical/ConversionUtils.cpp @@ -151,7 +151,8 @@ ConstTensorPin ConvertOperandToConstTensorPin(const Operand& operand, const ConversionData& data, const armnn::PermutationVector& dimensionMappings, const armnn::TensorShape* overrideTensorShape, - bool optional) + bool optional, + const armnn::DataType* overrideDataType) { if (!IsOperandTypeSupportedForTensors(operand.type)) { @@ -180,13 +181,18 @@ ConstTensorPin ConvertOperandToConstTensorPin(const Operand& operand, armnn::TensorInfo tensorInfo = GetTensorInfoForOperand(operand); - // Make sure isConstant flag is set. - tensorInfo.SetConstant(); - - if (overrideTensorShape != nullptr) + if (overrideTensorShape) { tensorInfo.SetShape(*overrideTensorShape); } + + if (overrideDataType) + { + tensorInfo.SetDataType(*overrideDataType); + } + + // Make sure isConstant flag is set. + tensorInfo.SetConstant(); return ConstTensorPin(tensorInfo, valueStart, operand.location.length, dimensionMappings); } @@ -194,7 +200,8 @@ LayerInputHandle ConvertToLayerInputHandle(const Operation& operation, uint32_t inputIndex, const Model& model, ConversionData& data, - const armnn::PermutationVector& dimensionMappings) + const armnn::PermutationVector& dimensionMappings, + const LayerInputHandle* inputHandle) { const Operand* operand = GetInputOperand(operation, inputIndex, model); @@ -268,8 +275,26 @@ LayerInputHandle ConvertToLayerInputHandle(const Operation& operation, case OperandLifeTime::POINTER: case OperandLifeTime::CONSTANT_REFERENCE: { + auto constantTensorDataType = operandTensorInfo.GetDataType(); + if (inputHandle) + { + if ((inputHandle->GetTensorInfo().GetDataType() == armnn::DataType::Float32 + || inputHandle->GetTensorInfo().GetDataType() == armnn::DataType::Float16) + && (operandTensorInfo.GetDataType() == armnn::DataType::QAsymmU8 + || operandTensorInfo.GetDataType() == armnn::DataType::QAsymmS8)) + { + constantTensorDataType = inputHandle->GetTensorInfo().GetDataType(); + } + } + // The tensor has an already known constant value, and can be converted into an ArmNN Constant layer. - ConstTensorPin tensorPin = ConvertOperandToConstTensorPin(*operand, model, data, dimensionMappings); + ConstTensorPin tensorPin = ConvertOperandToConstTensorPin(*operand, + model, + data, + dimensionMappings, + nullptr, + false, + &constantTensorDataType); if (tensorPin.IsValid()) { bool isSupported = false; |