aboutsummaryrefslogtreecommitdiff
path: root/shim/sl/canonical/ConversionUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'shim/sl/canonical/ConversionUtils.cpp')
-rw-r--r--shim/sl/canonical/ConversionUtils.cpp39
1 files changed, 32 insertions, 7 deletions
diff --git a/shim/sl/canonical/ConversionUtils.cpp b/shim/sl/canonical/ConversionUtils.cpp
index 020410d30e..96a8ddca6a 100644
--- a/shim/sl/canonical/ConversionUtils.cpp
+++ b/shim/sl/canonical/ConversionUtils.cpp
@@ -151,7 +151,8 @@ ConstTensorPin ConvertOperandToConstTensorPin(const Operand& operand,
const ConversionData& data,
const armnn::PermutationVector& dimensionMappings,
const armnn::TensorShape* overrideTensorShape,
- bool optional)
+ bool optional,
+ const armnn::DataType* overrideDataType)
{
if (!IsOperandTypeSupportedForTensors(operand.type))
{
@@ -180,13 +181,18 @@ ConstTensorPin ConvertOperandToConstTensorPin(const Operand& operand,
armnn::TensorInfo tensorInfo = GetTensorInfoForOperand(operand);
- // Make sure isConstant flag is set.
- tensorInfo.SetConstant();
-
- if (overrideTensorShape != nullptr)
+ if (overrideTensorShape)
{
tensorInfo.SetShape(*overrideTensorShape);
}
+
+ if (overrideDataType)
+ {
+ tensorInfo.SetDataType(*overrideDataType);
+ }
+
+ // Make sure isConstant flag is set.
+ tensorInfo.SetConstant();
return ConstTensorPin(tensorInfo, valueStart, operand.location.length, dimensionMappings);
}
@@ -194,7 +200,8 @@ LayerInputHandle ConvertToLayerInputHandle(const Operation& operation,
uint32_t inputIndex,
const Model& model,
ConversionData& data,
- const armnn::PermutationVector& dimensionMappings)
+ const armnn::PermutationVector& dimensionMappings,
+ const LayerInputHandle* inputHandle)
{
const Operand* operand = GetInputOperand(operation, inputIndex, model);
@@ -268,8 +275,26 @@ LayerInputHandle ConvertToLayerInputHandle(const Operation& operation,
case OperandLifeTime::POINTER:
case OperandLifeTime::CONSTANT_REFERENCE:
{
+ auto constantTensorDataType = operandTensorInfo.GetDataType();
+ if (inputHandle)
+ {
+ if ((inputHandle->GetTensorInfo().GetDataType() == armnn::DataType::Float32
+ || inputHandle->GetTensorInfo().GetDataType() == armnn::DataType::Float16)
+ && (operandTensorInfo.GetDataType() == armnn::DataType::QAsymmU8
+ || operandTensorInfo.GetDataType() == armnn::DataType::QAsymmS8))
+ {
+ constantTensorDataType = inputHandle->GetTensorInfo().GetDataType();
+ }
+ }
+
// The tensor has an already known constant value, and can be converted into an ArmNN Constant layer.
- ConstTensorPin tensorPin = ConvertOperandToConstTensorPin(*operand, model, data, dimensionMappings);
+ ConstTensorPin tensorPin = ConvertOperandToConstTensorPin(*operand,
+ model,
+ data,
+ dimensionMappings,
+ nullptr,
+ false,
+ &constantTensorDataType);
if (tensorPin.IsValid())
{
bool isSupported = false;