diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-06-04 13:04:16 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-06-24 14:56:23 +0000 |
commit | 3d13af8a39f408318328a95d5329bc17fd923438 (patch) | |
tree | b0d9c82062e229f8938d2c9f762ee67758196bf3 /tests/validation/reference | |
parent | db09b3783ff9af67c6d373b12aa9a6aff3c5d0f1 (diff) | |
download | ComputeLibrary-3d13af8a39f408318328a95d5329bc17fd923438.tar.gz |
COMPMID-2235: Extend type support for CL/NEON DequantizationLayer.
Adds support for:
- QSYMM8
Change-Id: Ia0b839fc844ce0f968dad1b69a001f9a660dbcd5
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1378
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/DequantizationLayer.cpp | 74 | ||||
-rw-r--r-- | tests/validation/reference/DequantizationLayer.h | 4 |
2 files changed, 69 insertions, 9 deletions
diff --git a/tests/validation/reference/DequantizationLayer.cpp b/tests/validation/reference/DequantizationLayer.cpp index 286a609d79..d07371c883 100644 --- a/tests/validation/reference/DequantizationLayer.cpp +++ b/tests/validation/reference/DequantizationLayer.cpp @@ -23,6 +23,8 @@ */ #include "DequantizationLayer.h" +#include "Permute.h" + namespace arm_compute { namespace test @@ -31,24 +33,82 @@ namespace validation { namespace reference { -template <typename T> -SimpleTensor<T> dequantization_layer(const SimpleTensor<uint8_t> &src) +namespace +{ +template <typename TOut> +TOut dequantize(int8_t val, const UniformQuantizationInfo qinfo) +{ + return static_cast<TOut>(dequantize_qsymm8(val, qinfo)); +} +template <typename TOut> +TOut dequantize(uint8_t val, const UniformQuantizationInfo qinfo) +{ + return static_cast<TOut>(dequantize_qasymm8(val, qinfo)); +} + +template <typename TOut, typename TIn> +SimpleTensor<TOut> dequantization_layer_nchw(const SimpleTensor<TIn> &src) { - const DataType dst_data_type = std::is_same<T, float>::value ? DataType::F32 : DataType::F16; - const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform(); + const DataType src_data_type = src.data_type(); + const DataType dst_data_type = std::is_same<TOut, float>::value ? DataType::F32 : DataType::F16; - SimpleTensor<T> dst{ src.shape(), dst_data_type }; + SimpleTensor<TOut> dst{ src.shape(), dst_data_type }; - for(int i = 0; i < src.num_elements(); ++i) + if(src_data_type == DataType::QSYMM8_PER_CHANNEL) { - dst[i] = static_cast<T>(dequantize_qasymm8(src[i], quantization_info)); + const int WH = src.shape().x() * src.shape().y(); + const int C = src.shape().z(); + const int N = src.shape().total_size() / (WH * C); + + const std::vector<float> qscales = src.quantization_info().scale(); + + for(int n = 0; n < N; ++n) + { + for(int c = 0; c < C; ++c) + { + const size_t idx = n * C * WH + c * WH; + const UniformQuantizationInfo channel_qinfo = { qscales[c], 0 }; + + // Dequantize slice + for(int s = 0; s < WH; ++s) + { + dst[idx + s] = dequantize<TOut>(src[idx + s], channel_qinfo); + } + } + } + } + else + { + const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform(); + ARM_COMPUTE_ERROR_ON(quantization_info.offset != 0 && src_data_type == DataType::QSYMM8); + + for(int i = 0; i < src.num_elements(); ++i) + { + dst[i] = static_cast<TOut>(dequantize<TOut>(src[i], quantization_info)); + } } return dst; } +} // namespace +template <typename TOut, typename TIn> +SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src) +{ + if(src.data_layout() == DataLayout::NHWC && src.data_type() == DataType::QSYMM8_PER_CHANNEL) + { + SimpleTensor<TIn> src_nchw = reference::permute<TIn>(src, PermutationVector(1U, 2U, 0U)); + return reference::permute<TOut>(dequantization_layer_nchw<TOut>(src_nchw), PermutationVector(2U, 0U, 1U)); + } + else + { + return dequantization_layer_nchw<TOut>(src); + } +} template SimpleTensor<half> dequantization_layer(const SimpleTensor<uint8_t> &src); template SimpleTensor<float> dequantization_layer(const SimpleTensor<uint8_t> &src); +template SimpleTensor<half> dequantization_layer(const SimpleTensor<int8_t> &src); +template SimpleTensor<float> dequantization_layer(const SimpleTensor<int8_t> &src); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/DequantizationLayer.h b/tests/validation/reference/DequantizationLayer.h index 1d0e54b442..8c780849fd 100644 --- a/tests/validation/reference/DequantizationLayer.h +++ b/tests/validation/reference/DequantizationLayer.h @@ -35,8 +35,8 @@ namespace validation { namespace reference { -template <typename T> -SimpleTensor<T> dequantization_layer(const SimpleTensor<uint8_t> &src); +template <typename TOut, typename TIn> +SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src); } // namespace reference } // namespace validation } // namespace test |