aboutsummaryrefslogtreecommitdiff
path: root/src/backends/aclCommon/ArmComputeTensorUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/aclCommon/ArmComputeTensorUtils.cpp')
-rw-r--r--src/backends/aclCommon/ArmComputeTensorUtils.cpp13
1 files changed, 9 insertions, 4 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index 1cad92f58a..04202ada90 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -13,7 +13,7 @@ namespace armnn
namespace armcomputetensorutils
{
-arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
+arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType, bool multiScales)
{
switch(dataType)
{
@@ -28,9 +28,13 @@ arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
case armnn::DataType::QSymmS16:
return arm_compute::DataType::QSYMM16;
case armnn::DataType::QSymmS8:
- return arm_compute::DataType::QSYMM8;
+ {
+ return multiScales ? arm_compute::DataType::QSYMM8_PER_CHANNEL : arm_compute::DataType::QSYMM8;
+ }
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
case armnn::DataType::QuantizedSymm8PerAxis:
return arm_compute::DataType::QSYMM8_PER_CHANNEL;
+ ARMNN_NO_DEPRECATE_WARN_END
case armnn::DataType::Signed32:
return arm_compute::DataType::S32;
default:
@@ -109,10 +113,11 @@ arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& te
// ARM Compute Tensor and CLTensor allocators.
arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
{
+ bool multiScales = tensorInfo.HasMultipleQuantizationScales();
const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
- const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
+ const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType(), multiScales);
- const arm_compute::QuantizationInfo aclQuantizationInfo = tensorInfo.HasMultipleQuantizationScales() ?
+ const arm_compute::QuantizationInfo aclQuantizationInfo = multiScales ?
arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) :
arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset());