diff options
Diffstat (limited to 'src/backends/aclCommon/ArmComputeTensorUtils.cpp')
-rw-r--r-- | src/backends/aclCommon/ArmComputeTensorUtils.cpp | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp index 1cad92f58a..04202ada90 100644 --- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp +++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp @@ -13,7 +13,7 @@ namespace armnn namespace armcomputetensorutils { -arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType) +arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType, bool multiScales) { switch(dataType) { @@ -28,9 +28,13 @@ arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType) case armnn::DataType::QSymmS16: return arm_compute::DataType::QSYMM16; case armnn::DataType::QSymmS8: - return arm_compute::DataType::QSYMM8; + { + return multiScales ? arm_compute::DataType::QSYMM8_PER_CHANNEL : arm_compute::DataType::QSYMM8; + } + ARMNN_NO_DEPRECATE_WARN_BEGIN case armnn::DataType::QuantizedSymm8PerAxis: return arm_compute::DataType::QSYMM8_PER_CHANNEL; + ARMNN_NO_DEPRECATE_WARN_END case armnn::DataType::Signed32: return arm_compute::DataType::S32; default: @@ -109,10 +113,11 @@ arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& te // ARM Compute Tensor and CLTensor allocators. arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo) { + bool multiScales = tensorInfo.HasMultipleQuantizationScales(); const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape()); - const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType()); + const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType(), multiScales); - const arm_compute::QuantizationInfo aclQuantizationInfo = tensorInfo.HasMultipleQuantizationScales() ? + const arm_compute::QuantizationInfo aclQuantizationInfo = multiScales ? arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) : arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset()); |