diff options
author | Ramy Elgammal <ramy.elgammal@arm.com> | 2023-01-30 04:56:47 +0000 |
---|---|---|
committer | Ramy Elgammal <ramy.elgammal@arm.com> | 2023-03-17 12:45:26 +0000 |
commit | 14d7b535d48620f009efca576cc70fb6ea9ff20d (patch) | |
tree | 592ae31798825fba0b71b36adda008fb53381e9f /tests/validation/reference/ElementwiseUnary.cpp | |
parent | 2b6ebfe4270b06b09e45f306e8384950aeca7e4e (diff) | |
download | ComputeLibrary-14d7b535d48620f009efca576cc70fb6ea9ff20d.tar.gz |
Implementation of RSQRT for quantized int8
Resolves: COMPMID-5863
Change-Id: I9ff67face62826c1d335a6b941e8516be39bdac8
Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com>
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/488768
Tested-by: bsgcomp <bsgcomp@arm.com>
Comments-Addressed: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9225
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/reference/ElementwiseUnary.cpp')
-rw-r--r-- | tests/validation/reference/ElementwiseUnary.cpp | 85 |
1 files changed, 76 insertions, 9 deletions
diff --git a/tests/validation/reference/ElementwiseUnary.cpp b/tests/validation/reference/ElementwiseUnary.cpp index 5333b53c15..d5218d772d 100644 --- a/tests/validation/reference/ElementwiseUnary.cpp +++ b/tests/validation/reference/ElementwiseUnary.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2020, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -22,7 +22,8 @@ * SOFTWARE. */ #include "ElementwiseUnary.h" - +#include "tests/validation/Helpers.h" +#include "utils/TypePrinter.h" namespace arm_compute { namespace test @@ -32,10 +33,8 @@ namespace validation namespace reference { template <typename T> -SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, ElementWiseUnary op) +SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, SimpleTensor<T> &dst, ElementWiseUnary op) { - SimpleTensor<T> dst(src.shape(), src.data_type()); - for(int i = 0; i < src.num_elements(); ++i) { switch(op) @@ -65,13 +64,81 @@ SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, ElementWiseUnary o ARM_COMPUTE_ERROR("Not implemented"); } } - return dst; } +template <> +SimpleTensor<int8_t> elementwise_unary(const SimpleTensor<int8_t> &src, SimpleTensor<int8_t> &dst, ElementWiseUnary op) +{ + if(dst.data_type() == DataType::QASYMM8_SIGNED) + { + SimpleTensor<float> src_tmp = convert_from_asymmetric(src); + SimpleTensor<float> dst_tmp(src.shape(), DataType::F32); + for(int i = 0; i < src.num_elements(); ++i) + { + switch(op) + { + case ElementWiseUnary::RSQRT: + if(src_tmp[i] != 0) + { + dst_tmp[i] = 1.f / std::sqrt(src_tmp[i]); + } + else + { + // rsqrt(0) give 'inf' so set to the maximum in int8: 127 + dst_tmp[i] = (127.0f - dst.quantization_info().uniform().offset) * dst.quantization_info().uniform().scale ; + } + break; + default: + ARM_COMPUTE_ERROR("Not implemented"); + } + } + dst = convert_to_asymmetric<int8_t>(dst_tmp, dst.quantization_info()); + } + else + { + ARM_COMPUTE_ERROR("Not implemented"); + } + return dst; +} +template <> +SimpleTensor<uint8_t> elementwise_unary(const SimpleTensor<uint8_t> &src, SimpleTensor<uint8_t> &dst, ElementWiseUnary op) +{ + if(dst.data_type() == DataType::QASYMM8) + { + SimpleTensor<float> src_tmp = convert_from_asymmetric(src); + SimpleTensor<float> dst_tmp(src.shape(), DataType::F32); + for(int i = 0; i < src.num_elements(); ++i) + { + switch(op) + { + case ElementWiseUnary::RSQRT: + if(src_tmp[i] != 0) + { + dst_tmp[i] = 1.f / std::sqrt(src_tmp[i]); + } + else + { + // rsqrt(0) give 'inf' so set to the maximum in uint8: 255 + dst_tmp[i] = (255.0f - dst.quantization_info().uniform().offset)* dst.quantization_info().uniform().scale; + } + break; + default: + ARM_COMPUTE_ERROR("Not implemented"); + } + } + dst = convert_to_asymmetric<uint8_t>(dst_tmp, dst.quantization_info()); + } + else + { + ARM_COMPUTE_ERROR("Not implemented"); + } + return dst; +} + +template SimpleTensor<float> elementwise_unary(const SimpleTensor<float> &src, SimpleTensor<float> &dst, ElementWiseUnary op); +template SimpleTensor<half> elementwise_unary(const SimpleTensor<half> &src, SimpleTensor<half> &dst, ElementWiseUnary op); +template SimpleTensor<int32_t> elementwise_unary(const SimpleTensor<int32_t> &src, SimpleTensor<int32_t> &dst, ElementWiseUnary op); -template SimpleTensor<float> elementwise_unary(const SimpleTensor<float> &src, ElementWiseUnary op); -template SimpleTensor<half> elementwise_unary(const SimpleTensor<half> &src, ElementWiseUnary op); -template SimpleTensor<int32_t> elementwise_unary(const SimpleTensor<int32_t> &src, ElementWiseUnary op); } // namespace reference } // namespace validation } // namespace test |