diff options
-rw-r--r-- | Android.bp | 1 | ||||
-rw-r--r-- | SConscript | 1 | ||||
-rw-r--r-- | src/core/CL/cl_kernels/common/elementwise_unary_quantized.cl | 77 | ||||
-rw-r--r-- | src/gpu/cl/ClKernelLibrary.cpp | 5 | ||||
-rw-r--r-- | src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp | 41 | ||||
-rw-r--r-- | tests/validation/CL/RsqrtLayer.cpp | 35 | ||||
-rw-r--r-- | tests/validation/fixtures/ElementwiseUnaryFixture.h | 44 | ||||
-rw-r--r-- | tests/validation/reference/ElementwiseUnary.cpp | 85 | ||||
-rw-r--r-- | tests/validation/reference/ElementwiseUnary.h | 4 |
9 files changed, 253 insertions, 40 deletions
diff --git a/Android.bp b/Android.bp index 42116d78ca..848ae1049b 100644 --- a/Android.bp +++ b/Android.bp @@ -27,6 +27,7 @@ opencl_srcs = [ "src/core/CL/cl_kernels/common/elementwise_operation.cl", "src/core/CL/cl_kernels/common/elementwise_operation_quantized.cl", "src/core/CL/cl_kernels/common/elementwise_unary.cl", + "src/core/CL/cl_kernels/common/elementwise_unary_quantized.cl", "src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/fp_post_ops_act_eltwise_op_act.h", "src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl", "src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl", diff --git a/SConscript b/SConscript index 205764b9a7..03b94b6bd1 100644 --- a/SConscript +++ b/SConscript @@ -369,6 +369,7 @@ if env['opencl'] and env['embed_kernels']: 'src/core/CL/cl_kernels/common/elementwise_operation.cl', 'src/core/CL/cl_kernels/common/elementwise_operation_quantized.cl', 'src/core/CL/cl_kernels/common/elementwise_unary.cl', + 'src/core/CL/cl_kernels/common/elementwise_unary_quantized.cl', 'src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl', 'src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl', 'src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl', diff --git a/src/core/CL/cl_kernels/common/elementwise_unary_quantized.cl b/src/core/CL/cl_kernels/common/elementwise_unary_quantized.cl new file mode 100644 index 0000000000..2e4cdc53fe --- /dev/null +++ b/src/core/CL/cl_kernels/common/elementwise_unary_quantized.cl @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "helpers.h" + +#if defined(DATA_TYPE) && defined(OPERATION) +// Calculate reverse square root +#define rsqrt_op(input) rsqrt(input) +#if defined(VEC_SIZE) +#define VEC_FLOAT VEC_DATA_TYPE(float, VEC_SIZE) +#define VEC_INT VEC_DATA_TYPE(int, VEC_SIZE) +#define VEC_TYPE VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) +#endif // defined(VEC_SIZE) + +/** Applies element wise unary operator in a tensor. + * + * @param[in] in_ptr Pointer to the source image. Supported data types: QASYMM8/QASYMM8_SIGNED. + * @param[in] in_stride_x Stride of the source tensor in X dimension (in bytes) + * @param[in] in_step_x in_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] in_stride_y Stride of the source tensor in Y dimension (in bytes) + * @param[in] in_step_y in_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] in_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] in_step_z in_stride_z * number of elements along Z processed per workitem(in bytes) + * @param[in] in_offset_first_element_in_bytes Offset of the first element in the source image + * @param[out] out_ptr Pointer to the destination image. Supported data types: QASYMM8/QASYMM8_SIGNED. + * @param[in] out_stride_x Stride of the destination image in X dimension (in bytes) + * @param[in] out_step_x out_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] out_step_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] out_step_y out_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] out_stride_z Stride of the destination tensor in Z dimension (in bytes) + * @param[in] out_step_z out_stride_z * number of elements along Z processed per workitem(in bytes) + * @param[in] out_offset_first_element_in_bytes Offset of the first element in the destination image + */ +__kernel void elementwise_unary_quantized( + TENSOR3D_DECLARATION(in), + TENSOR3D_DECLARATION(out)) +{ + Tensor3D in = CONVERT_TO_TENSOR3D_STRUCT(in); + Tensor3D out = CONVERT_TO_TENSOR3D_STRUCT(out); + + // Check if access on width gets out of bounds + // If it does shift access vector to access elements within bounds + const int xi = (int)(get_global_id(0) * VEC_SIZE); + in.ptr -= max(xi - (int)LAST_ACCESSED_X, 0) * in_stride_x; + out.ptr -= max(xi - (int)LAST_ACCESSED_X, 0) * out_stride_x; + + VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) + data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)in.ptr); + VEC_DATA_TYPE(float, VEC_SIZE) + data_f32 = CONVERT(data, VEC_FLOAT); + data_f32 = (data_f32 - (float)OFFSET_IN) * (float)SCALE_IN; + VEC_INT qres_int = CONVERT_SAT((OPERATION(data_f32) / ((VEC_FLOAT)(float)SCALE_OUT)), VEC_INT) + ((VEC_INT)((int)OFFSET_OUT)); + const VEC_TYPE qres = CONVERT_SAT(qres_int, VEC_TYPE); + VSTORE(VEC_SIZE) + (qres, 0, (__global DATA_TYPE *)out.ptr); +} +#endif // defined(DATA_TYPE) && defined(OPERATION) diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp index 482e8c341d..8099071fcd 100644 --- a/src/gpu/cl/ClKernelLibrary.cpp +++ b/src/gpu/cl/ClKernelLibrary.cpp @@ -235,6 +235,7 @@ const std::map<std::string, std::string> ClKernelLibrary::_kernel_program_map = { "elementwise_operation_SQUARED_DIFF_quantized", "common/elementwise_operation_quantized.cl" }, { "elementwise_operation_PRELU_quantized", "common/elementwise_operation_quantized.cl" }, { "elementwise_unary", "common/elementwise_unary.cl" }, + { "elementwise_unary_quantized", "common/elementwise_unary_quantized.cl" }, { "fft_digit_reverse_axis_0", "common/fft_digit_reverse.cl" }, { "fft_digit_reverse_axis_1", "common/fft_digit_reverse.cl" }, { "fft_radix_2_first_stage_axis_0", "common/fft.cl" }, @@ -572,6 +573,10 @@ const std::map<std::string, std::string> ClKernelLibrary::_program_source_map = #include "./cl_kernels/common/elementwise_unary.clembed" }, { + "common/elementwise_unary_quantized.cl", +#include "./cl_kernels/common/elementwise_unary_quantized.clembed" + }, + { "common/fft.cl", #include "./cl_kernels/common/fft.clembed" }, diff --git a/src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp b/src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp index 1f09515b86..40b1eaca1f 100644 --- a/src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp +++ b/src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -38,6 +38,8 @@ namespace kernels { namespace { +constexpr unsigned int vector_size_byte_opencl = 16; + Status validate_arguments(const ITensorInfo &src, const ITensorInfo &dst, const ElementWiseUnary op) { ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&src); @@ -49,6 +51,10 @@ Status validate_arguments(const ITensorInfo &src, const ITensorInfo &dst, const { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::F16, DataType::F32, DataType::S32); } + else if(op == ElementWiseUnary::RSQRT) // Allow quantized types for only RSQRT. + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::F16, DataType::F32, DataType::QASYMM8, DataType::QASYMM8_SIGNED); + } else { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::F16, DataType::F32); @@ -78,17 +84,29 @@ void ClElementWiseUnaryKernel::configure(const CLCompileContext &compile_context auto padding_info = get_padding_info({ src, dst }); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src, *dst, op)); + const unsigned int num_elems_processed_per_iteration = adjust_vec_size(vector_size_byte_opencl / dst->element_size(), dst->dimension(0)); - const std::string kernel_name = "elementwise_unary"; - const int vec_size_x = 16 / dst->element_size(); - const int dst_width_x = dst->tensor_shape().x(); - const bool multi_access_x = (dst_width_x / vec_size_x > 0); - + std::string kernel_name = "elementwise_unary"; + const int vec_size_x = num_elems_processed_per_iteration; + const int dst_width_x = dst->dimension(0); + if(is_data_type_quantized(src->data_type())) + { + kernel_name += "_quantized"; + } // Set kernel build options CLBuildOptions build_opts; build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src->data_type())); - build_opts.add_option_if(multi_access_x, "-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x)); - build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(dst_width_x - vec_size_x, 0))); + build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x)); + build_opts.add_option("-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(dst_width_x - vec_size_x, 0))); + if(is_data_type_quantized(src->data_type())) + { + const UniformQuantizationInfo iqinfo = src->quantization_info().uniform(); + const UniformQuantizationInfo oqinfo = dst->quantization_info().uniform(); + build_opts.add_option("-DOFFSET_IN=" + support::cpp11::to_string(iqinfo.offset)); + build_opts.add_option("-DOFFSET_OUT=" + support::cpp11::to_string(oqinfo.offset)); + build_opts.add_option("-DSCALE_IN=" + float_to_string_with_full_precision(iqinfo.scale)); + build_opts.add_option("-DSCALE_OUT=" + float_to_string_with_full_precision(oqinfo.scale)); + } switch(op) { case ElementWiseUnary::RSQRT: @@ -124,11 +142,8 @@ void ClElementWiseUnaryKernel::configure(const CLCompileContext &compile_context // Configure kernel window Window win = calculate_max_window(*dst); - if(multi_access_x) - { - win.set(Window::DimX, - Window::Dimension(win.x().start(), ceil_to_multiple(win.x().end(), vec_size_x), vec_size_x)); - } + win.set(Window::DimX, Window::Dimension(win.x().start(), ceil_to_multiple(win.x().end(), vec_size_x), vec_size_x)); + ICLKernel::configure_internal(win); ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info)); diff --git a/tests/validation/CL/RsqrtLayer.cpp b/tests/validation/CL/RsqrtLayer.cpp index 936d853d34..2353bda8d3 100644 --- a/tests/validation/CL/RsqrtLayer.cpp +++ b/tests/validation/CL/RsqrtLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 Arm Limited. + * Copyright (c) 2018-2019, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -42,8 +42,11 @@ namespace validation { namespace { -RelativeTolerance<float> tolerance_fp32(0.000001f); -RelativeTolerance<float> tolerance_fp16(0.001f); +RelativeTolerance<float> tolerance_fp32(0.000001f); +RelativeTolerance<float> tolerance_fp16(0.001f); +constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for unsigned 8-bit asymmetric type */ +constexpr AbsoluteTolerance<int8_t> tolerance_qasymm8_s(1); /**< Tolerance value for comparing reference's output against implementation's output for signed 8-bit asymmetric type */ + } // namespace TEST_SUITE(CL) TEST_SUITE(RsqrtLayer) @@ -68,6 +71,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip( // *INDENT-ON* template <typename T> using CLRsqrtLayerFixture = RsqrtValidationFixture<CLTensor, CLAccessor, CLRsqrtLayer, T>; +template <typename T> +using CLRsqrtLayerQuantizedFixture = RsqrtQuantizedValidationFixture<CLTensor, CLAccessor, CLRsqrtLayer, T>; TEST_SUITE(Float) TEST_SUITE(FP16) @@ -102,6 +107,30 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLRsqrtLayerFixture<float>, framework::DatasetM TEST_SUITE_END() // FP32 TEST_SUITE_END() // Float +TEST_SUITE(Quantized) +TEST_SUITE(QASYMM8_SIGNED) +FIXTURE_DATA_TEST_CASE(RunSmall, CLRsqrtLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::QASYMM8_SIGNED)), + framework::dataset::make("SrcQInfo", { QuantizationInfo(0.4044, -128) })), + framework::dataset::make("OutQInfo", { QuantizationInfo(0.0027, -128) }))) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_qasymm8_s); +} +TEST_SUITE_END() // QASYMM8_SIGNED +TEST_SUITE(QASYMM8) + +FIXTURE_DATA_TEST_CASE(RunSmall, CLRsqrtLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::QASYMM8)), + framework::dataset::make("SrcQInfo", { QuantizationInfo(0.4044, 0) })), + framework::dataset::make("OutQInfo", { QuantizationInfo(0.0027, 0) }))) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_qasymm8); +} +TEST_SUITE_END() // QASYMM8 +TEST_SUITE_END() // Quantized + TEST_SUITE_END() // RsqrtLayer TEST_SUITE_END() // CL } // namespace validation diff --git a/tests/validation/fixtures/ElementwiseUnaryFixture.h b/tests/validation/fixtures/ElementwiseUnaryFixture.h index 7221226fd1..1dc4f03e99 100644 --- a/tests/validation/fixtures/ElementwiseUnaryFixture.h +++ b/tests/validation/fixtures/ElementwiseUnaryFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -44,11 +44,12 @@ class ElementWiseUnaryValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape input_shape, DataType input_data_type, bool in_place, ElementWiseUnary op, bool use_dynamic_shape = false) + void setup(TensorShape input_shape, DataType input_data_type, bool in_place, ElementWiseUnary op, + bool use_dynamic_shape = false, QuantizationInfo qinfo = QuantizationInfo(), QuantizationInfo qinfo_out = QuantizationInfo()) { _op = op; - _target = compute_target(input_shape, input_data_type, in_place); - _reference = compute_reference(input_shape, input_data_type); + _target = compute_target(input_shape, input_data_type, in_place, qinfo, qinfo_out); + _reference = compute_reference(input_shape, input_data_type, qinfo, qinfo_out); _use_dynamic_shape = use_dynamic_shape; } @@ -69,8 +70,15 @@ protected: } case ElementWiseUnary::RSQRT: { - FloatDistributionType distribution{ FloatType(1.0f), FloatType(2.0f) }; - library->fill(tensor, distribution, i); + if(data_type == DataType::F32 || data_type == DataType::F16) + { + FloatDistributionType distribution{ FloatType(1.0f), FloatType(2.0f) }; + library->fill(tensor, distribution, i); + } + else + { + library->fill_tensor_uniform(tensor, i); + } break; } case ElementWiseUnary::ABS: @@ -124,12 +132,11 @@ protected: } } - TensorType compute_target(const TensorShape &shape, DataType data_type, bool in_place) + TensorType compute_target(const TensorShape &shape, DataType data_type, bool in_place, QuantizationInfo qinfo, QuantizationInfo qinfo_out) { // Create tensors - TensorType src = create_tensor<TensorType>(shape, data_type); - TensorType dst = create_tensor<TensorType>(shape, data_type); - + TensorType src = create_tensor<TensorType>(shape, data_type, 1, qinfo); + TensorType dst = create_tensor<TensorType>(shape, data_type, 1, qinfo_out); TensorType *actual_dst = in_place ? &src : &dst; // if _use_dynamic_shape is true, this fixture will test scenario for dynamic shapes. @@ -176,15 +183,16 @@ protected: } } - SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type) + SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo, QuantizationInfo qinfo_out) { // Create reference - SimpleTensor<T> src{ shape, data_type }; + SimpleTensor<T> src{ shape, data_type, 1, qinfo }; + SimpleTensor<T> dst{ shape, data_type, 1, qinfo_out }; // Fill reference fill(src, 0, data_type); - return reference::elementwise_unary<T>(src, _op); + return reference::elementwise_unary<T>(src, dst, _op); } TensorType _target{}; @@ -192,6 +200,16 @@ protected: ElementWiseUnary _op{}; bool _use_dynamic_shape{ false }; }; +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class RsqrtQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo, QuantizationInfo qinfo_out) + { + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::RSQRT, false, qinfo, qinfo_out); + } +}; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class RsqrtValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> diff --git a/tests/validation/reference/ElementwiseUnary.cpp b/tests/validation/reference/ElementwiseUnary.cpp index 5333b53c15..d5218d772d 100644 --- a/tests/validation/reference/ElementwiseUnary.cpp +++ b/tests/validation/reference/ElementwiseUnary.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2020, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -22,7 +22,8 @@ * SOFTWARE. */ #include "ElementwiseUnary.h" - +#include "tests/validation/Helpers.h" +#include "utils/TypePrinter.h" namespace arm_compute { namespace test @@ -32,10 +33,8 @@ namespace validation namespace reference { template <typename T> -SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, ElementWiseUnary op) +SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, SimpleTensor<T> &dst, ElementWiseUnary op) { - SimpleTensor<T> dst(src.shape(), src.data_type()); - for(int i = 0; i < src.num_elements(); ++i) { switch(op) @@ -65,13 +64,81 @@ SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, ElementWiseUnary o ARM_COMPUTE_ERROR("Not implemented"); } } - return dst; } +template <> +SimpleTensor<int8_t> elementwise_unary(const SimpleTensor<int8_t> &src, SimpleTensor<int8_t> &dst, ElementWiseUnary op) +{ + if(dst.data_type() == DataType::QASYMM8_SIGNED) + { + SimpleTensor<float> src_tmp = convert_from_asymmetric(src); + SimpleTensor<float> dst_tmp(src.shape(), DataType::F32); + for(int i = 0; i < src.num_elements(); ++i) + { + switch(op) + { + case ElementWiseUnary::RSQRT: + if(src_tmp[i] != 0) + { + dst_tmp[i] = 1.f / std::sqrt(src_tmp[i]); + } + else + { + // rsqrt(0) give 'inf' so set to the maximum in int8: 127 + dst_tmp[i] = (127.0f - dst.quantization_info().uniform().offset) * dst.quantization_info().uniform().scale ; + } + break; + default: + ARM_COMPUTE_ERROR("Not implemented"); + } + } + dst = convert_to_asymmetric<int8_t>(dst_tmp, dst.quantization_info()); + } + else + { + ARM_COMPUTE_ERROR("Not implemented"); + } + return dst; +} +template <> +SimpleTensor<uint8_t> elementwise_unary(const SimpleTensor<uint8_t> &src, SimpleTensor<uint8_t> &dst, ElementWiseUnary op) +{ + if(dst.data_type() == DataType::QASYMM8) + { + SimpleTensor<float> src_tmp = convert_from_asymmetric(src); + SimpleTensor<float> dst_tmp(src.shape(), DataType::F32); + for(int i = 0; i < src.num_elements(); ++i) + { + switch(op) + { + case ElementWiseUnary::RSQRT: + if(src_tmp[i] != 0) + { + dst_tmp[i] = 1.f / std::sqrt(src_tmp[i]); + } + else + { + // rsqrt(0) give 'inf' so set to the maximum in uint8: 255 + dst_tmp[i] = (255.0f - dst.quantization_info().uniform().offset)* dst.quantization_info().uniform().scale; + } + break; + default: + ARM_COMPUTE_ERROR("Not implemented"); + } + } + dst = convert_to_asymmetric<uint8_t>(dst_tmp, dst.quantization_info()); + } + else + { + ARM_COMPUTE_ERROR("Not implemented"); + } + return dst; +} + +template SimpleTensor<float> elementwise_unary(const SimpleTensor<float> &src, SimpleTensor<float> &dst, ElementWiseUnary op); +template SimpleTensor<half> elementwise_unary(const SimpleTensor<half> &src, SimpleTensor<half> &dst, ElementWiseUnary op); +template SimpleTensor<int32_t> elementwise_unary(const SimpleTensor<int32_t> &src, SimpleTensor<int32_t> &dst, ElementWiseUnary op); -template SimpleTensor<float> elementwise_unary(const SimpleTensor<float> &src, ElementWiseUnary op); -template SimpleTensor<half> elementwise_unary(const SimpleTensor<half> &src, ElementWiseUnary op); -template SimpleTensor<int32_t> elementwise_unary(const SimpleTensor<int32_t> &src, ElementWiseUnary op); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/ElementwiseUnary.h b/tests/validation/reference/ElementwiseUnary.h index be4a229a5b..ae7a49bce4 100644 --- a/tests/validation/reference/ElementwiseUnary.h +++ b/tests/validation/reference/ElementwiseUnary.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 Arm Limited. + * Copyright (c) 2018-2019, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -35,7 +35,7 @@ namespace validation namespace reference { template <typename T> -SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, ElementWiseUnary op); +SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, SimpleTensor<T> &dst, ElementWiseUnary op); } // namespace reference } // namespace validation } // namespace test |