author     Derek Lamberti <derek.lamberti@arm.com>   2020-01-22 15:37:29 +0000
committer  Derek Lamberti <derek.lamberti@arm.com>   2020-01-24 09:21:47 +0000
commit     d466a54e79560f0ccacc6b13cd64e08defbac47c (patch)
tree       d8f8d5226d71178aed32c6ad407570071e55dded
parent     4a3c61091037e7e86e8b03bb060d8c1ab82731a9 (diff)
download   armnn-d466a54e79560f0ccacc6b13cd64e08defbac47c.tar.gz
IVGCVSW-4370 Deprecate DataType::QuantizedSymm8PerAxis
!android-nn-driver:2622

Change-Id: If99d3eff71ff66ba28af1e5af248299fe04511b9
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
-rw-r--r--  include/armnn/Deprecated.hpp | 2
-rw-r--r--  include/armnn/Types.hpp | 2
-rw-r--r--  include/armnn/TypesUtils.hpp | 12
-rw-r--r--  src/armnn/CompatibleTypes.hpp | 2
-rw-r--r--  src/armnn/Tensor.cpp | 2
-rw-r--r--  src/backends/aclCommon/ArmComputeTensorUtils.cpp | 13
-rw-r--r--  src/backends/aclCommon/ArmComputeTensorUtils.hpp | 2
-rw-r--r--  src/backends/backendsCommon/LayerSupportRules.hpp | 8
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp | 21
-rw-r--r--  src/backends/backendsCommon/WorkloadUtils.cpp | 6
-rw-r--r--  src/backends/backendsCommon/test/WorkloadDataValidation.cpp | 2
-rw-r--r--  src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp | 4
-rw-r--r--  src/backends/backendsCommon/test/layerTests/TransposeConvolution2dTestImpl.cpp | 2
-rw-r--r--  src/backends/cl/workloads/ClWorkloadUtils.hpp | 6
-rw-r--r--  src/backends/neon/workloads/NeonWorkloadUtils.hpp | 6
-rw-r--r--  src/backends/reference/RefLayerSupport.cpp | 31
-rw-r--r--  src/backends/reference/workloads/Decoders.hpp | 21
-rw-r--r--  src/backends/reference/workloads/Encoders.hpp | 21
18 files changed, 129 insertions, 34 deletions
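
For context, the migration this patch enables looks roughly like the sketch below: per-axis weights keep DataType::QSymmS8 and carry one scale per channel in their TensorInfo, instead of using the dedicated enumerator. Shapes and scale values here are illustrative, not taken from the patch.

    #include <armnn/Tensor.hpp>
    #include <armnn/Types.hpp>

    #include <vector>

    int main()
    {
        using namespace armnn;

        // One scale per output channel; the per-axis property is inferred
        // from the number of scales rather than from a dedicated data type.
        const std::vector<float> scales = { 0.1f, 0.2f, 0.3f, 0.4f };
        const unsigned int quantDim = 0; // axis the scales apply to

        TensorInfo weights(TensorShape({ 4, 3, 3, 3 }), DataType::QSymmS8, scales, quantDim);

        // The backend changes below key off exactly these two queries.
        return (weights.HasPerAxisQuantization() &&
                weights.HasMultipleQuantizationScales()) ? 0 : 1;
    }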
diff --git a/include/armnn/Deprecated.hpp b/include/armnn/Deprecated.hpp
index 73871772bc..2b9240fbc4 100644
--- a/include/armnn/Deprecated.hpp
+++ b/include/armnn/Deprecated.hpp
@@ -42,7 +42,7 @@ ARMNN_NO_DEPRECATE_WARN_END
#define ARMNN_DEPRECATED [[deprecated]]
#define ARMNN_DEPRECATED_MSG(message) [[deprecated(message)]]
-#if defined(__GNUC__) && (__GNUC__ <= 6)
+#if defined(__GNUC__) && (__GNUC__ < 6)
# define ARMNN_DEPRECATED_ENUM
# define ARMNN_DEPRECATED_ENUM_MSG(message)
#else
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index 1ab5660109..b0f5a08bd3 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -37,7 +37,7 @@ enum class DataType
Signed32 = 3,
Boolean = 4,
QSymmS16 = 5,
- QuantizedSymm8PerAxis = 6,
+ QuantizedSymm8PerAxis ARMNN_DEPRECATED_ENUM_MSG("Per Axis property inferred by number of scales in TensorInfo") = 6,
QSymmS8 = 7,
QuantisedAsymm8 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QAsymmU8 instead.") = QAsymmU8,
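
With the enumerator now carrying ARMNN_DEPRECATED_ENUM_MSG, any code that still names QuantizedSymm8PerAxis gets a compile-time deprecation warning. The rest of the patch therefore brackets every remaining reference with the suppression macros from Deprecated.hpp; a minimal sketch of the pattern, not taken from the patch:

    #include <armnn/Deprecated.hpp>
    #include <armnn/Types.hpp>

    bool IsDeprecatedPerAxisEnum(armnn::DataType type)
    {
        // Outside the macro pair this reference would emit:
        // "Per Axis property inferred by number of scales in TensorInfo"
        ARMNN_NO_DEPRECATE_WARN_BEGIN
        return type == armnn::DataType::QuantizedSymm8PerAxis;
        ARMNN_NO_DEPRECATE_WARN_END
    }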
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index 790f57a432..257e39f363 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -119,8 +119,10 @@ constexpr unsigned int GetDataTypeSize(DataType dataType)
case DataType::Signed32: return 4U;
case DataType::QAsymmU8: return 1U;
case DataType::QSymmS8: return 1U;
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
case DataType::QuantizedSymm8PerAxis: return 1U;
- case DataType::QSymmS16: return 2U;
+ ARMNN_NO_DEPRECATE_WARN_END
+ case DataType::QSymmS16: return 2U;
case DataType::Boolean: return 1U;
default: return 0U;
}
@@ -167,8 +169,10 @@ constexpr const char* GetDataTypeName(DataType dataType)
case DataType::Float32: return "Float32";
case DataType::QAsymmU8: return "QAsymmU8";
case DataType::QSymmS8: return "QSymmS8";
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
case DataType::QuantizedSymm8PerAxis: return "QSymm8PerAxis";
- case DataType::QSymmS16: return "QSymm16";
+ ARMNN_NO_DEPRECATE_WARN_END
+ case DataType::QSymmS16: return "QSymm16";
case DataType::Signed32: return "Signed32";
case DataType::Boolean: return "Boolean";
@@ -230,10 +234,12 @@ constexpr bool IsQuantizedType()
constexpr bool IsQuantizedType(DataType dataType)
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
return dataType == DataType::QAsymmU8 ||
dataType == DataType::QSymmS8 ||
- dataType == DataType::QSymmS16 ||
+ dataType == DataType::QSymmS16 ||
dataType == DataType::QuantizedSymm8PerAxis;
+ ARMNN_NO_DEPRECATE_WARN_END
}
inline std::ostream& operator<<(std::ostream& os, Status stat)
diff --git a/src/armnn/CompatibleTypes.hpp b/src/armnn/CompatibleTypes.hpp
index bca092ca0c..8603a1bc38 100644
--- a/src/armnn/CompatibleTypes.hpp
+++ b/src/armnn/CompatibleTypes.hpp
@@ -38,7 +38,9 @@ inline bool CompatibleTypes<uint8_t>(DataType dataType)
template<>
inline bool CompatibleTypes<int8_t>(DataType dataType)
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
return dataType == DataType::QSymmS8 || dataType == DataType::QuantizedSymm8PerAxis;
+ ARMNN_NO_DEPRECATE_WARN_END
}
template<>
diff --git a/src/armnn/Tensor.cpp b/src/armnn/Tensor.cpp
index 8eebc43cb5..aeb7ab5fdd 100644
--- a/src/armnn/Tensor.cpp
+++ b/src/armnn/Tensor.cpp
@@ -289,7 +289,7 @@ void TensorInfo::SetQuantizationDim(const Optional<unsigned int>& quantizationDi
bool TensorInfo::IsQuantized() const
{
- return m_DataType == DataType::QAsymmU8 || m_DataType == DataType::QSymmS16;
+ return IsQuantizedType(m_DataType);
}
// ---
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index 1cad92f58a..04202ada90 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -13,7 +13,7 @@ namespace armnn
namespace armcomputetensorutils
{
-arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
+arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType, bool multiScales)
{
switch(dataType)
{
@@ -28,9 +28,13 @@ arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
case armnn::DataType::QSymmS16:
return arm_compute::DataType::QSYMM16;
case armnn::DataType::QSymmS8:
- return arm_compute::DataType::QSYMM8;
+ {
+ return multiScales ? arm_compute::DataType::QSYMM8_PER_CHANNEL : arm_compute::DataType::QSYMM8;
+ }
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
case armnn::DataType::QuantizedSymm8PerAxis:
return arm_compute::DataType::QSYMM8_PER_CHANNEL;
+ ARMNN_NO_DEPRECATE_WARN_END
case armnn::DataType::Signed32:
return arm_compute::DataType::S32;
default:
@@ -109,10 +113,11 @@ arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& te
// ARM Compute Tensor and CLTensor allocators.
arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
{
+ bool multiScales = tensorInfo.HasMultipleQuantizationScales();
const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
- const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
+ const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType(), multiScales);
- const arm_compute::QuantizationInfo aclQuantizationInfo = tensorInfo.HasMultipleQuantizationScales() ?
+ const arm_compute::QuantizationInfo aclQuantizationInfo = multiScales ?
arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) :
arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset());
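
The effect of the new multiScales flag is that QSymmS8 now maps to either of the two ACL symmetric types. A minimal sketch of the mapping, assuming an ACL-enabled build with the internal aclCommon headers on the include path (the wrapper function is illustrative, not part of the patch):

    #include <aclCommon/ArmComputeTensorUtils.hpp>

    arm_compute::DataType MapSymmS8(bool multiScales)
    {
        // multiScales == true  -> arm_compute::DataType::QSYMM8_PER_CHANNEL
        // multiScales == false -> arm_compute::DataType::QSYMM8
        return armnn::armcomputetensorutils::GetArmComputeDataType(
            armnn::DataType::QSymmS8, multiScales);
    }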
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.hpp b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
index 3fc6818b0d..01d1dea53d 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.hpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
@@ -24,7 +24,7 @@ namespace armcomputetensorutils
{
/// Utility function to map an armnn::DataType to corresponding arm_compute::DataType.
-arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType);
+arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType, bool multiScales);
/// Utility function used to set up an arm_compute::Coordinates from a vector of ArmNN Axes for reduction functions
arm_compute::Coordinates BuildArmComputeReductionCoordinates(size_t inputDimensions,
diff --git a/src/backends/backendsCommon/LayerSupportRules.hpp b/src/backends/backendsCommon/LayerSupportRules.hpp
index d8b6af8a30..3a2ae06f5a 100644
--- a/src/backends/backendsCommon/LayerSupportRules.hpp
+++ b/src/backends/backendsCommon/LayerSupportRules.hpp
@@ -106,6 +106,14 @@ struct TypeIs : public Rule
}
};
+struct TypeNotPerAxisQuantized : public Rule
+{
+ TypeNotPerAxisQuantized(const TensorInfo& info)
+ {
+ m_Res = !info.IsQuantized() || !info.HasPerAxisQuantization();
+ }
+};
+
struct BiasAndWeightsTypesMatch : public Rule
{
BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
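
The new rule follows the existing Rule subclasses and evaluates to false only for tensors that are both quantized and per-axis; RefLayerSupport.cpp below wires it into CheckSupportRule for dequantize. A minimal standalone check, assuming the backendsCommon headers are on the include path (illustrative, not part of the patch):

    #include <backendsCommon/LayerSupportRules.hpp>

    #include <armnn/Tensor.hpp>

    #include <vector>

    bool RuleRejectsPerAxisInput()
    {
        using namespace armnn;

        const std::vector<float> scales = { 0.1f, 0.2f };
        TensorInfo perAxis  (TensorShape({ 2, 4 }), DataType::QSymmS8, scales, 0);
        TensorInfo perTensor(TensorShape({ 2, 4 }), DataType::QSymmS8, 0.1f, 0);

        // Fails for the per-axis tensor, passes for the per-tensor one.
        return !TypeNotPerAxisQuantized(perAxis)() &&
                TypeNotPerAxisQuantized(perTensor)();
    }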
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index d2ab41ef40..075884b2da 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -149,6 +149,19 @@ void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
}
}
+void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
+{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
+ if (tensor.GetDataType() != DataType::QSymmS8 &&
+ tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
+ {
+ throw InvalidArgumentException(descName +
+ ": Expected data type which supports per-axis quantization scheme but got " +
+ GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
+ }
+ ARMNN_NO_DEPRECATE_WARN_END
+}
+
//---------------------------------------------------------------
void ValidateTensorQuantizationSpace(const TensorInfo& first,
const TensorInfo& second,
@@ -344,11 +357,14 @@ void ValidateWeightDataType(const TensorInfo& inputInfo,
const DataType inputType = inputInfo.GetDataType();
if (inputType == DataType::QAsymmU8)
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
const std::vector<DataType> validTypes =
{
DataType::QAsymmU8,
- DataType::QuantizedSymm8PerAxis
+ DataType::QSymmS8,
+ DataType::QuantizedSymm8PerAxis // deprecated
};
+ ARMNN_NO_DEPRECATE_WARN_END
ValidateDataTypes(weightInfo, validTypes, descName);
}
@@ -412,7 +428,8 @@ void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
"but data type does not support per-axis quantization.") % descName % "weight"));
}
- ValidateTensorDataType(weightInfo, DataType::QuantizedSymm8PerAxis, descName, "weight");
+
+ ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
diff --git a/src/backends/backendsCommon/WorkloadUtils.cpp b/src/backends/backendsCommon/WorkloadUtils.cpp
index cb1f7c117a..69a62914e5 100644
--- a/src/backends/backendsCommon/WorkloadUtils.cpp
+++ b/src/backends/backendsCommon/WorkloadUtils.cpp
@@ -5,6 +5,8 @@
#include <backendsCommon/WorkloadUtils.hpp>
+#include <armnn/Utils.hpp>
+
namespace armnn
{
@@ -167,9 +169,13 @@ armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstCpuTensorHandle*
case DataType::QAsymmU8:
weightPermuted = ReorderWeightChannelsForAcl<uint8_t>(weightPermuted, dataLayout, permuteBuffer);
break;
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
case DataType::QuantizedSymm8PerAxis:
+ ARMNN_FALLTHROUGH;
+ case DataType::QSymmS8:
weightPermuted = ReorderWeightChannelsForAcl<int8_t>(weightPermuted, dataLayout, permuteBuffer);
break;
+ ARMNN_NO_DEPRECATE_WARN_END
default:
break;
}
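
The new <armnn/Utils.hpp> include appears to be what provides ARMNN_FALLTHROUGH, which marks the deliberate fall-through from the deprecated enumerator into the QSymmS8 case so implicit-fallthrough warnings stay clean. The pattern in isolation (illustrative sketch, not from the patch):

    #include <armnn/Deprecated.hpp>
    #include <armnn/Types.hpp>
    #include <armnn/Utils.hpp>

    bool TakesInt8Path(armnn::DataType type)
    {
        ARMNN_NO_DEPRECATE_WARN_BEGIN
        switch (type)
        {
            case armnn::DataType::QuantizedSymm8PerAxis: // deprecated alias
                ARMNN_FALLTHROUGH;
            case armnn::DataType::QSymmS8:
                return true;
            default:
                return false;
        }
        ARMNN_NO_DEPRECATE_WARN_END
    }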
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
index 3c47eab01f..5c60e9e552 100644
--- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
+++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
@@ -616,7 +616,7 @@ BOOST_AUTO_TEST_CASE(BiasPerAxisQuantization_Validate)
const TensorShape biasShape { cOutput };
constexpr DataType inputType = DataType::QAsymmU8;
- constexpr DataType weightType = DataType::QuantizedSymm8PerAxis;
+ constexpr DataType weightType = DataType::QSymmS8;
constexpr DataType biasType = DataType::Signed32;
constexpr float perTensorScale = 1.5f;
diff --git a/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
index b0b2981d8d..669398fb54 100644
--- a/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
@@ -3049,7 +3049,7 @@ LayerTestResult<uint8_t, 4> Convolution2dPerAxisQuantTest(
using namespace armnn;
const DataType inputType = DataType::QAsymmU8;
- const DataType kernelType = DataType::QuantizedSymm8PerAxis;
+ const DataType kernelType = DataType::QSymmS8;
const DataType biasType = DataType::Signed32;
TensorInfo inputInfo ({ 1, 3, 1, 2 }, inputType, 0.5f, 128);
@@ -3273,7 +3273,7 @@ LayerTestResult<uint8_t, 4> DepthwiseConvolution2dPerAxisQuantTest(
using namespace armnn;
const DataType inputType = DataType::QAsymmU8;
- const DataType kernelType = DataType::QuantizedSymm8PerAxis;
+ const DataType kernelType = DataType::QSymmS8;
const DataType biasType = DataType::Signed32;
TensorInfo inputInfo ({ 1, 3, 3, 2 }, inputType, 0.5f, 128); // N H W C
diff --git a/src/backends/backendsCommon/test/layerTests/TransposeConvolution2dTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/TransposeConvolution2dTestImpl.cpp
index 1c880752c8..378ec46bd1 100644
--- a/src/backends/backendsCommon/test/layerTests/TransposeConvolution2dTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/TransposeConvolution2dTestImpl.cpp
@@ -566,7 +566,7 @@ LayerTestResult<uint8_t, 4> TransposeConvolution2dPerAxisQuantTest(
using namespace armnn;
const DataType inputType = DataType::QAsymmU8;
- const DataType kernelType = DataType::QuantizedSymm8PerAxis;
+ const DataType kernelType = DataType::QSymmS8;
const DataType biasType = DataType::Signed32;
TensorInfo inputInfo ({ 1, 1, 2, 2 }, inputType, 0.50f, 10);
diff --git a/src/backends/cl/workloads/ClWorkloadUtils.hpp b/src/backends/cl/workloads/ClWorkloadUtils.hpp
index c5cfcd8fc1..709300681c 100644
--- a/src/backends/cl/workloads/ClWorkloadUtils.hpp
+++ b/src/backends/cl/workloads/ClWorkloadUtils.hpp
@@ -10,6 +10,8 @@
#include <cl/OpenClTimer.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
+#include <armnn/Utils.hpp>
+
#include <arm_compute/runtime/CL/CLFunctions.h>
#include <sstream>
@@ -101,9 +103,13 @@ inline void InitializeArmComputeClTensorData(arm_compute::CLTensor& clTensor,
case DataType::QAsymmU8:
CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<uint8_t>());
break;
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
case DataType::QuantizedSymm8PerAxis:
+ ARMNN_FALLTHROUGH;
+ case DataType::QSymmS8:
CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<int8_t>());
break;
+ ARMNN_NO_DEPRECATE_WARN_END
case DataType::Signed32:
CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<int32_t>());
break;
diff --git a/src/backends/neon/workloads/NeonWorkloadUtils.hpp b/src/backends/neon/workloads/NeonWorkloadUtils.hpp
index f98fe44039..3f0fe842aa 100644
--- a/src/backends/neon/workloads/NeonWorkloadUtils.hpp
+++ b/src/backends/neon/workloads/NeonWorkloadUtils.hpp
@@ -10,6 +10,8 @@
#include <neon/NeonTimer.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
+#include <armnn/Utils.hpp>
+
#include <Half.hpp>
#define ARMNN_SCOPED_PROFILING_EVENT_NEON(name) \
@@ -46,9 +48,13 @@ inline void InitializeArmComputeTensorData(arm_compute::Tensor& tensor,
case DataType::QAsymmU8:
CopyArmComputeTensorData(tensor, handle->GetConstTensor<uint8_t>());
break;
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
case DataType::QuantizedSymm8PerAxis:
+ ARMNN_FALLTHROUGH;
+ case DataType::QSymmS8:
CopyArmComputeTensorData(tensor, handle->GetConstTensor<int8_t>());
break;
+ ARMNN_NO_DEPRECATE_WARN_END
case DataType::Signed32:
CopyArmComputeTensorData(tensor, handle->GetConstTensor<int32_t>());
break;
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 491081dbac..ee6462dfa3 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -437,11 +437,14 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
const DataType inputType = input.GetDataType();
if (inputType == DataType::QAsymmU8)
{
- std::array<DataType, 2> supportedWeightTypes =
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
+ std::array<DataType, 3> supportedWeightTypes =
{
DataType::QAsymmU8,
- DataType::QuantizedSymm8PerAxis
+ DataType::QSymmS8,
+ DataType::QuantizedSymm8PerAxis // deprecated
};
+ ARMNN_NO_DEPRECATE_WARN_END
supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
"Reference convolution2d: weights type not supported for quantized input.");
@@ -554,14 +557,18 @@ bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
"Reference DepthwiseConvolution2d: input and output types mismatched.");
- const DataType inputType = input.GetDataType();
- if (inputType == DataType::QAsymmU8)
- {
- std::array<DataType, 2> supportedWeightTypes =
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
+ std::array<DataType, 3> supportedWeightTypes =
{
DataType::QAsymmU8,
- DataType::QuantizedSymm8PerAxis
+ DataType::QSymmS8,
+ DataType::QuantizedSymm8PerAxis // deprecated
};
+ ARMNN_NO_DEPRECATE_WARN_END
+
+ const DataType inputType = input.GetDataType();
+ if (inputType == DataType::QAsymmU8)
+ {
supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
"Reference convolution2d: weights type not supported for quantized input.");
@@ -607,6 +614,9 @@ bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
"Reference dequantize: input type not supported.");
+ supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
+ "Reference dequantize: per-axis quantized input not support .");
+
std::array<DataType,2> supportedOutputTypes = {
DataType::Float32,
DataType::Float16
@@ -1836,11 +1846,14 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
const DataType inputType = input.GetDataType();
if (inputType == DataType::QAsymmU8)
{
- std::array<DataType, 2> supportedWeightTypes =
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
+ std::array<DataType, 3> supportedWeightTypes =
{
DataType::QAsymmU8,
- DataType::QuantizedSymm8PerAxis
+ DataType::QSymmS8,
+ DataType::QuantizedSymm8PerAxis // deprecated
};
+ ARMNN_NO_DEPRECATE_WARN_END
supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
"Reference TransposeConvolution2d: weights type not supported for "
diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp
index faabdcdb3f..6f309787bd 100644
--- a/src/backends/reference/workloads/Decoders.hpp
+++ b/src/backends/reference/workloads/Decoders.hpp
@@ -71,6 +71,7 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const
{
switch(info.GetDataType())
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
case armnn::DataType::QuantizedSymm8PerAxis:
{
std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
@@ -79,6 +80,7 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const
params.second,
params.first);
}
+ ARMNN_NO_DEPRECATE_WARN_END
case DataType::QAsymmU8:
{
return std::make_unique<QASymm8Decoder>(
@@ -107,10 +109,21 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const
}
case DataType::QSymmS8:
{
- return std::make_unique<QSymmS8Decoder>(
- static_cast<const int8_t*>(data),
- info.GetQuantizationScale(),
- info.GetQuantizationOffset());
+ if (info.HasPerAxisQuantization())
+ {
+ std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
+ return std::make_unique<QSymm8PerAxisDecoder>(
+ static_cast<const int8_t*>(data),
+ params.second,
+ params.first);
+ }
+ else
+ {
+ return std::make_unique<QSymmS8Decoder>(
+ static_cast<const int8_t*>(data),
+ info.GetQuantizationScale(),
+ info.GetQuantizationOffset());
+ }
}
default:
{
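
After this change a caller no longer needs the deprecated enumerator to obtain a per-axis decoder; the factory keys off HasPerAxisQuantization(). A hypothetical caller, assuming the internal reference-backend include paths (this is not public API, and the scales are illustrative):

    #include <Decoders.hpp>

    #include <armnn/Tensor.hpp>

    #include <memory>
    #include <vector>

    std::unique_ptr<armnn::Decoder<float>> MakeWeightDecoder(const void* weightData)
    {
        using namespace armnn;

        const std::vector<float> scales = { 0.05f, 0.10f };
        TensorInfo info(TensorShape({ 2, 8 }), DataType::QSymmS8, scales, 0);

        // Multiple scales -> QSymm8PerAxisDecoder; a single per-tensor scale
        // would still select the plain QSymmS8Decoder.
        return MakeDecoder<float>(info, weightData);
    }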
diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp
index 4fe202f0bf..8ddd559448 100644
--- a/src/backends/reference/workloads/Encoders.hpp
+++ b/src/backends/reference/workloads/Encoders.hpp
@@ -22,6 +22,7 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void*
{
switch(info.GetDataType())
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
case armnn::DataType::QuantizedSymm8PerAxis:
{
std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
@@ -30,6 +31,7 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void*
params.second,
params.first);
}
+ ARMNN_NO_DEPRECATE_WARN_END
case armnn::DataType::QAsymmU8:
{
return std::make_unique<QASymm8Encoder>(
@@ -39,10 +41,21 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void*
}
case DataType::QSymmS8:
{
- return std::make_unique<QSymmS8Encoder>(
- static_cast<int8_t*>(data),
- info.GetQuantizationScale(),
- info.GetQuantizationOffset());
+ if (info.HasPerAxisQuantization())
+ {
+ std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
+ return std::make_unique<QSymm8PerAxisEncoder>(
+ static_cast<int8_t*>(data),
+ params.second,
+ params.first);
+ }
+ else
+ {
+ return std::make_unique<QSymmS8Encoder>(
+ static_cast<int8_t*>(data),
+ info.GetQuantizationScale(),
+ info.GetQuantizationOffset());
+ }
}
case armnn::DataType::QSymmS16:
{