diff options
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r-- | src/backends/reference/workloads/Decoders.hpp | 21 | ||||
-rw-r--r-- | src/backends/reference/workloads/Encoders.hpp | 21 |
2 files changed, 34 insertions, 8 deletions
diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp index faabdcdb3f..6f309787bd 100644 --- a/src/backends/reference/workloads/Decoders.hpp +++ b/src/backends/reference/workloads/Decoders.hpp @@ -71,6 +71,7 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const { switch(info.GetDataType()) { + ARMNN_NO_DEPRECATE_WARN_BEGIN case armnn::DataType::QuantizedSymm8PerAxis: { std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info); @@ -79,6 +80,7 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const params.second, params.first); } + ARMNN_NO_DEPRECATE_WARN_END case DataType::QAsymmU8: { return std::make_unique<QASymm8Decoder>( @@ -107,10 +109,21 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const } case DataType::QSymmS8: { - return std::make_unique<QSymmS8Decoder>( - static_cast<const int8_t*>(data), - info.GetQuantizationScale(), - info.GetQuantizationOffset()); + if (info.HasPerAxisQuantization()) + { + std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info); + return std::make_unique<QSymm8PerAxisDecoder>( + static_cast<const int8_t*>(data), + params.second, + params.first); + } + else + { + return std::make_unique<QSymmS8Decoder>( + static_cast<const int8_t*>(data), + info.GetQuantizationScale(), + info.GetQuantizationOffset()); + } } default: { diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp index 4fe202f0bf..8ddd559448 100644 --- a/src/backends/reference/workloads/Encoders.hpp +++ b/src/backends/reference/workloads/Encoders.hpp @@ -22,6 +22,7 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void* { switch(info.GetDataType()) { + ARMNN_NO_DEPRECATE_WARN_BEGIN case armnn::DataType::QuantizedSymm8PerAxis: { std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info); @@ -30,6 +31,7 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void* params.second, params.first); } + ARMNN_NO_DEPRECATE_WARN_END case armnn::DataType::QAsymmU8: { return std::make_unique<QASymm8Encoder>( @@ -39,10 +41,21 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void* } case DataType::QSymmS8: { - return std::make_unique<QSymmS8Encoder>( - static_cast<int8_t*>(data), - info.GetQuantizationScale(), - info.GetQuantizationOffset()); + if (info.HasPerAxisQuantization()) + { + std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info); + return std::make_unique<QSymm8PerAxisEncoder>( + static_cast<int8_t*>(data), + params.second, + params.first); + } + else + { + return std::make_unique<QSymmS8Encoder>( + static_cast<int8_t*>(data), + info.GetQuantizationScale(), + info.GetQuantizationOffset()); + } } case armnn::DataType::QSymmS16: { |