diff options
Diffstat (limited to 'tests/validation/reference/DequantizationLayer.cpp')
-rw-r--r-- | tests/validation/reference/DequantizationLayer.cpp | 74 |
1 files changed, 67 insertions, 7 deletions
diff --git a/tests/validation/reference/DequantizationLayer.cpp b/tests/validation/reference/DequantizationLayer.cpp index 286a609d79..d07371c883 100644 --- a/tests/validation/reference/DequantizationLayer.cpp +++ b/tests/validation/reference/DequantizationLayer.cpp @@ -23,6 +23,8 @@ */ #include "DequantizationLayer.h" +#include "Permute.h" + namespace arm_compute { namespace test @@ -31,24 +33,82 @@ namespace validation { namespace reference { -template <typename T> -SimpleTensor<T> dequantization_layer(const SimpleTensor<uint8_t> &src) +namespace +{ +template <typename TOut> +TOut dequantize(int8_t val, const UniformQuantizationInfo qinfo) +{ + return static_cast<TOut>(dequantize_qsymm8(val, qinfo)); +} +template <typename TOut> +TOut dequantize(uint8_t val, const UniformQuantizationInfo qinfo) +{ + return static_cast<TOut>(dequantize_qasymm8(val, qinfo)); +} + +template <typename TOut, typename TIn> +SimpleTensor<TOut> dequantization_layer_nchw(const SimpleTensor<TIn> &src) { - const DataType dst_data_type = std::is_same<T, float>::value ? DataType::F32 : DataType::F16; - const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform(); + const DataType src_data_type = src.data_type(); + const DataType dst_data_type = std::is_same<TOut, float>::value ? DataType::F32 : DataType::F16; - SimpleTensor<T> dst{ src.shape(), dst_data_type }; + SimpleTensor<TOut> dst{ src.shape(), dst_data_type }; - for(int i = 0; i < src.num_elements(); ++i) + if(src_data_type == DataType::QSYMM8_PER_CHANNEL) { - dst[i] = static_cast<T>(dequantize_qasymm8(src[i], quantization_info)); + const int WH = src.shape().x() * src.shape().y(); + const int C = src.shape().z(); + const int N = src.shape().total_size() / (WH * C); + + const std::vector<float> qscales = src.quantization_info().scale(); + + for(int n = 0; n < N; ++n) + { + for(int c = 0; c < C; ++c) + { + const size_t idx = n * C * WH + c * WH; + const UniformQuantizationInfo channel_qinfo = { qscales[c], 0 }; + + // Dequantize slice + for(int s = 0; s < WH; ++s) + { + dst[idx + s] = dequantize<TOut>(src[idx + s], channel_qinfo); + } + } + } + } + else + { + const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform(); + ARM_COMPUTE_ERROR_ON(quantization_info.offset != 0 && src_data_type == DataType::QSYMM8); + + for(int i = 0; i < src.num_elements(); ++i) + { + dst[i] = static_cast<TOut>(dequantize<TOut>(src[i], quantization_info)); + } } return dst; } +} // namespace +template <typename TOut, typename TIn> +SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src) +{ + if(src.data_layout() == DataLayout::NHWC && src.data_type() == DataType::QSYMM8_PER_CHANNEL) + { + SimpleTensor<TIn> src_nchw = reference::permute<TIn>(src, PermutationVector(1U, 2U, 0U)); + return reference::permute<TOut>(dequantization_layer_nchw<TOut>(src_nchw), PermutationVector(2U, 0U, 1U)); + } + else + { + return dequantization_layer_nchw<TOut>(src); + } +} template SimpleTensor<half> dequantization_layer(const SimpleTensor<uint8_t> &src); template SimpleTensor<float> dequantization_layer(const SimpleTensor<uint8_t> &src); +template SimpleTensor<half> dequantization_layer(const SimpleTensor<int8_t> &src); +template SimpleTensor<float> dequantization_layer(const SimpleTensor<int8_t> &src); } // namespace reference } // namespace validation } // namespace test |