aboutsummaryrefslogtreecommitdiff
path: root/src/backends/aclCommon/ArmComputeTensorUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/aclCommon/ArmComputeTensorUtils.cpp')
-rw-r--r--src/backends/aclCommon/ArmComputeTensorUtils.cpp19
1 files changed, 8 insertions, 11 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index c7d250a706..b2955b9259 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -106,9 +106,11 @@ arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& te
arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
{
const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
- const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
- const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
- tensorInfo.GetQuantizationOffset());
+ const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
+
+ const arm_compute::QuantizationInfo aclQuantizationInfo = tensorInfo.HasMultipleQuantizationScales() ?
+ arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) :
+ arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset());
return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
}
@@ -116,15 +118,10 @@ arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tenso
arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
armnn::DataLayout dataLayout)
{
- const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
- const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
- const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
- tensorInfo.GetQuantizationOffset());
-
- arm_compute::TensorInfo clTensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
- clTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
+ arm_compute::TensorInfo aclTensorInfo = BuildArmComputeTensorInfo(tensorInfo);
+ aclTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
- return clTensorInfo;
+ return aclTensorInfo;
}
arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout)