aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Android.bp1
-rw-r--r--SConscript1
-rw-r--r--src/core/CL/cl_kernels/common/elementwise_unary_quantized.cl77
-rw-r--r--src/gpu/cl/ClKernelLibrary.cpp5
-rw-r--r--src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp41
-rw-r--r--tests/validation/CL/RsqrtLayer.cpp35
-rw-r--r--tests/validation/fixtures/ElementwiseUnaryFixture.h44
-rw-r--r--tests/validation/reference/ElementwiseUnary.cpp85
-rw-r--r--tests/validation/reference/ElementwiseUnary.h4
9 files changed, 253 insertions, 40 deletions
diff --git a/Android.bp b/Android.bp
index 42116d78ca..848ae1049b 100644
--- a/Android.bp
+++ b/Android.bp
@@ -27,6 +27,7 @@ opencl_srcs = [
"src/core/CL/cl_kernels/common/elementwise_operation.cl",
"src/core/CL/cl_kernels/common/elementwise_operation_quantized.cl",
"src/core/CL/cl_kernels/common/elementwise_unary.cl",
+ "src/core/CL/cl_kernels/common/elementwise_unary_quantized.cl",
"src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/fp_post_ops_act_eltwise_op_act.h",
"src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl",
"src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl",
diff --git a/SConscript b/SConscript
index 205764b9a7..03b94b6bd1 100644
--- a/SConscript
+++ b/SConscript
@@ -369,6 +369,7 @@ if env['opencl'] and env['embed_kernels']:
'src/core/CL/cl_kernels/common/elementwise_operation.cl',
'src/core/CL/cl_kernels/common/elementwise_operation_quantized.cl',
'src/core/CL/cl_kernels/common/elementwise_unary.cl',
+ 'src/core/CL/cl_kernels/common/elementwise_unary_quantized.cl',
'src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl',
'src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl',
'src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl',
diff --git a/src/core/CL/cl_kernels/common/elementwise_unary_quantized.cl b/src/core/CL/cl_kernels/common/elementwise_unary_quantized.cl
new file mode 100644
index 0000000000..2e4cdc53fe
--- /dev/null
+++ b/src/core/CL/cl_kernels/common/elementwise_unary_quantized.cl
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "helpers.h"
+
+#if defined(DATA_TYPE) && defined(OPERATION)
+// Calculate reverse square root
+#define rsqrt_op(input) rsqrt(input)
+#if defined(VEC_SIZE)
+#define VEC_FLOAT VEC_DATA_TYPE(float, VEC_SIZE)
+#define VEC_INT VEC_DATA_TYPE(int, VEC_SIZE)
+#define VEC_TYPE VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+#endif // defined(VEC_SIZE)
+
+/** Applies element wise unary operator in a tensor.
+ *
+ * @param[in] in_ptr Pointer to the source image. Supported data types: QASYMM8/QASYMM8_SIGNED.
+ * @param[in] in_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] in_step_x in_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] in_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] in_step_y in_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] in_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] in_step_z in_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] in_offset_first_element_in_bytes Offset of the first element in the source image
+ * @param[out] out_ptr Pointer to the destination image. Supported data types: QASYMM8/QASYMM8_SIGNED.
+ * @param[in] out_stride_x Stride of the destination image in X dimension (in bytes)
+ * @param[in] out_step_x out_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] out_step_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] out_step_y out_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] out_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] out_step_z out_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] out_offset_first_element_in_bytes Offset of the first element in the destination image
+ */
+__kernel void elementwise_unary_quantized(
+ TENSOR3D_DECLARATION(in),
+ TENSOR3D_DECLARATION(out))
+{
+ Tensor3D in = CONVERT_TO_TENSOR3D_STRUCT(in);
+ Tensor3D out = CONVERT_TO_TENSOR3D_STRUCT(out);
+
+ // Check if access on width gets out of bounds
+ // If it does shift access vector to access elements within bounds
+ const int xi = (int)(get_global_id(0) * VEC_SIZE);
+ in.ptr -= max(xi - (int)LAST_ACCESSED_X, 0) * in_stride_x;
+ out.ptr -= max(xi - (int)LAST_ACCESSED_X, 0) * out_stride_x;
+
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+ data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)in.ptr);
+ VEC_DATA_TYPE(float, VEC_SIZE)
+ data_f32 = CONVERT(data, VEC_FLOAT);
+ data_f32 = (data_f32 - (float)OFFSET_IN) * (float)SCALE_IN;
+ VEC_INT qres_int = CONVERT_SAT((OPERATION(data_f32) / ((VEC_FLOAT)(float)SCALE_OUT)), VEC_INT) + ((VEC_INT)((int)OFFSET_OUT));
+ const VEC_TYPE qres = CONVERT_SAT(qres_int, VEC_TYPE);
+ VSTORE(VEC_SIZE)
+ (qres, 0, (__global DATA_TYPE *)out.ptr);
+}
+#endif // defined(DATA_TYPE) && defined(OPERATION)
diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp
index 482e8c341d..8099071fcd 100644
--- a/src/gpu/cl/ClKernelLibrary.cpp
+++ b/src/gpu/cl/ClKernelLibrary.cpp
@@ -235,6 +235,7 @@ const std::map<std::string, std::string> ClKernelLibrary::_kernel_program_map =
{ "elementwise_operation_SQUARED_DIFF_quantized", "common/elementwise_operation_quantized.cl" },
{ "elementwise_operation_PRELU_quantized", "common/elementwise_operation_quantized.cl" },
{ "elementwise_unary", "common/elementwise_unary.cl" },
+ { "elementwise_unary_quantized", "common/elementwise_unary_quantized.cl" },
{ "fft_digit_reverse_axis_0", "common/fft_digit_reverse.cl" },
{ "fft_digit_reverse_axis_1", "common/fft_digit_reverse.cl" },
{ "fft_radix_2_first_stage_axis_0", "common/fft.cl" },
@@ -572,6 +573,10 @@ const std::map<std::string, std::string> ClKernelLibrary::_program_source_map =
#include "./cl_kernels/common/elementwise_unary.clembed"
},
{
+ "common/elementwise_unary_quantized.cl",
+#include "./cl_kernels/common/elementwise_unary_quantized.clembed"
+ },
+ {
"common/fft.cl",
#include "./cl_kernels/common/fft.clembed"
},
diff --git a/src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp b/src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp
index 1f09515b86..40b1eaca1f 100644
--- a/src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp
+++ b/src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,6 +38,8 @@ namespace kernels
{
namespace
{
+constexpr unsigned int vector_size_byte_opencl = 16;
+
Status validate_arguments(const ITensorInfo &src, const ITensorInfo &dst, const ElementWiseUnary op)
{
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&src);
@@ -49,6 +51,10 @@ Status validate_arguments(const ITensorInfo &src, const ITensorInfo &dst, const
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::F16, DataType::F32, DataType::S32);
}
+ else if(op == ElementWiseUnary::RSQRT) // Allow quantized types for only RSQRT.
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::F16, DataType::F32, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
+ }
else
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::F16, DataType::F32);
@@ -78,17 +84,29 @@ void ClElementWiseUnaryKernel::configure(const CLCompileContext &compile_context
auto padding_info = get_padding_info({ src, dst });
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src, *dst, op));
+ const unsigned int num_elems_processed_per_iteration = adjust_vec_size(vector_size_byte_opencl / dst->element_size(), dst->dimension(0));
- const std::string kernel_name = "elementwise_unary";
- const int vec_size_x = 16 / dst->element_size();
- const int dst_width_x = dst->tensor_shape().x();
- const bool multi_access_x = (dst_width_x / vec_size_x > 0);
-
+ std::string kernel_name = "elementwise_unary";
+ const int vec_size_x = num_elems_processed_per_iteration;
+ const int dst_width_x = dst->dimension(0);
+ if(is_data_type_quantized(src->data_type()))
+ {
+ kernel_name += "_quantized";
+ }
// Set kernel build options
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src->data_type()));
- build_opts.add_option_if(multi_access_x, "-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x));
- build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(dst_width_x - vec_size_x, 0)));
+ build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x));
+ build_opts.add_option("-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(dst_width_x - vec_size_x, 0)));
+ if(is_data_type_quantized(src->data_type()))
+ {
+ const UniformQuantizationInfo iqinfo = src->quantization_info().uniform();
+ const UniformQuantizationInfo oqinfo = dst->quantization_info().uniform();
+ build_opts.add_option("-DOFFSET_IN=" + support::cpp11::to_string(iqinfo.offset));
+ build_opts.add_option("-DOFFSET_OUT=" + support::cpp11::to_string(oqinfo.offset));
+ build_opts.add_option("-DSCALE_IN=" + float_to_string_with_full_precision(iqinfo.scale));
+ build_opts.add_option("-DSCALE_OUT=" + float_to_string_with_full_precision(oqinfo.scale));
+ }
switch(op)
{
case ElementWiseUnary::RSQRT:
@@ -124,11 +142,8 @@ void ClElementWiseUnaryKernel::configure(const CLCompileContext &compile_context
// Configure kernel window
Window win = calculate_max_window(*dst);
- if(multi_access_x)
- {
- win.set(Window::DimX,
- Window::Dimension(win.x().start(), ceil_to_multiple(win.x().end(), vec_size_x), vec_size_x));
- }
+ win.set(Window::DimX, Window::Dimension(win.x().start(), ceil_to_multiple(win.x().end(), vec_size_x), vec_size_x));
+
ICLKernel::configure_internal(win);
ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info));
diff --git a/tests/validation/CL/RsqrtLayer.cpp b/tests/validation/CL/RsqrtLayer.cpp
index 936d853d34..2353bda8d3 100644
--- a/tests/validation/CL/RsqrtLayer.cpp
+++ b/tests/validation/CL/RsqrtLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2019, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -42,8 +42,11 @@ namespace validation
{
namespace
{
-RelativeTolerance<float> tolerance_fp32(0.000001f);
-RelativeTolerance<float> tolerance_fp16(0.001f);
+RelativeTolerance<float> tolerance_fp32(0.000001f);
+RelativeTolerance<float> tolerance_fp16(0.001f);
+constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for unsigned 8-bit asymmetric type */
+constexpr AbsoluteTolerance<int8_t> tolerance_qasymm8_s(1); /**< Tolerance value for comparing reference's output against implementation's output for signed 8-bit asymmetric type */
+
} // namespace
TEST_SUITE(CL)
TEST_SUITE(RsqrtLayer)
@@ -68,6 +71,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(
// *INDENT-ON*
template <typename T>
using CLRsqrtLayerFixture = RsqrtValidationFixture<CLTensor, CLAccessor, CLRsqrtLayer, T>;
+template <typename T>
+using CLRsqrtLayerQuantizedFixture = RsqrtQuantizedValidationFixture<CLTensor, CLAccessor, CLRsqrtLayer, T>;
TEST_SUITE(Float)
TEST_SUITE(FP16)
@@ -102,6 +107,30 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLRsqrtLayerFixture<float>, framework::DatasetM
TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
+TEST_SUITE(Quantized)
+TEST_SUITE(QASYMM8_SIGNED)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLRsqrtLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("SrcQInfo", { QuantizationInfo(0.4044, -128) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(0.0027, -128) })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8_s);
+}
+TEST_SUITE_END() // QASYMM8_SIGNED
+TEST_SUITE(QASYMM8)
+
+FIXTURE_DATA_TEST_CASE(RunSmall, CLRsqrtLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::QASYMM8)),
+ framework::dataset::make("SrcQInfo", { QuantizationInfo(0.4044, 0) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(0.0027, 0) })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8);
+}
+TEST_SUITE_END() // QASYMM8
+TEST_SUITE_END() // Quantized
+
TEST_SUITE_END() // RsqrtLayer
TEST_SUITE_END() // CL
} // namespace validation
diff --git a/tests/validation/fixtures/ElementwiseUnaryFixture.h b/tests/validation/fixtures/ElementwiseUnaryFixture.h
index 7221226fd1..1dc4f03e99 100644
--- a/tests/validation/fixtures/ElementwiseUnaryFixture.h
+++ b/tests/validation/fixtures/ElementwiseUnaryFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -44,11 +44,12 @@ class ElementWiseUnaryValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, DataType input_data_type, bool in_place, ElementWiseUnary op, bool use_dynamic_shape = false)
+ void setup(TensorShape input_shape, DataType input_data_type, bool in_place, ElementWiseUnary op,
+ bool use_dynamic_shape = false, QuantizationInfo qinfo = QuantizationInfo(), QuantizationInfo qinfo_out = QuantizationInfo())
{
_op = op;
- _target = compute_target(input_shape, input_data_type, in_place);
- _reference = compute_reference(input_shape, input_data_type);
+ _target = compute_target(input_shape, input_data_type, in_place, qinfo, qinfo_out);
+ _reference = compute_reference(input_shape, input_data_type, qinfo, qinfo_out);
_use_dynamic_shape = use_dynamic_shape;
}
@@ -69,8 +70,15 @@ protected:
}
case ElementWiseUnary::RSQRT:
{
- FloatDistributionType distribution{ FloatType(1.0f), FloatType(2.0f) };
- library->fill(tensor, distribution, i);
+ if(data_type == DataType::F32 || data_type == DataType::F16)
+ {
+ FloatDistributionType distribution{ FloatType(1.0f), FloatType(2.0f) };
+ library->fill(tensor, distribution, i);
+ }
+ else
+ {
+ library->fill_tensor_uniform(tensor, i);
+ }
break;
}
case ElementWiseUnary::ABS:
@@ -124,12 +132,11 @@ protected:
}
}
- TensorType compute_target(const TensorShape &shape, DataType data_type, bool in_place)
+ TensorType compute_target(const TensorShape &shape, DataType data_type, bool in_place, QuantizationInfo qinfo, QuantizationInfo qinfo_out)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, data_type);
- TensorType dst = create_tensor<TensorType>(shape, data_type);
-
+ TensorType src = create_tensor<TensorType>(shape, data_type, 1, qinfo);
+ TensorType dst = create_tensor<TensorType>(shape, data_type, 1, qinfo_out);
TensorType *actual_dst = in_place ? &src : &dst;
// if _use_dynamic_shape is true, this fixture will test scenario for dynamic shapes.
@@ -176,15 +183,16 @@ protected:
}
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo, QuantizationInfo qinfo_out)
{
// Create reference
- SimpleTensor<T> src{ shape, data_type };
+ SimpleTensor<T> src{ shape, data_type, 1, qinfo };
+ SimpleTensor<T> dst{ shape, data_type, 1, qinfo_out };
// Fill reference
fill(src, 0, data_type);
- return reference::elementwise_unary<T>(src, _op);
+ return reference::elementwise_unary<T>(src, dst, _op);
}
TensorType _target{};
@@ -192,6 +200,16 @@ protected:
ElementWiseUnary _op{};
bool _use_dynamic_shape{ false };
};
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class RsqrtQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo, QuantizationInfo qinfo_out)
+ {
+ ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::RSQRT, false, qinfo, qinfo_out);
+ }
+};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class RsqrtValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
diff --git a/tests/validation/reference/ElementwiseUnary.cpp b/tests/validation/reference/ElementwiseUnary.cpp
index 5333b53c15..d5218d772d 100644
--- a/tests/validation/reference/ElementwiseUnary.cpp
+++ b/tests/validation/reference/ElementwiseUnary.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2020, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -22,7 +22,8 @@
* SOFTWARE.
*/
#include "ElementwiseUnary.h"
-
+#include "tests/validation/Helpers.h"
+#include "utils/TypePrinter.h"
namespace arm_compute
{
namespace test
@@ -32,10 +33,8 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, ElementWiseUnary op)
+SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, SimpleTensor<T> &dst, ElementWiseUnary op)
{
- SimpleTensor<T> dst(src.shape(), src.data_type());
-
for(int i = 0; i < src.num_elements(); ++i)
{
switch(op)
@@ -65,13 +64,81 @@ SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, ElementWiseUnary o
ARM_COMPUTE_ERROR("Not implemented");
}
}
-
return dst;
}
+template <>
+SimpleTensor<int8_t> elementwise_unary(const SimpleTensor<int8_t> &src, SimpleTensor<int8_t> &dst, ElementWiseUnary op)
+{
+ if(dst.data_type() == DataType::QASYMM8_SIGNED)
+ {
+ SimpleTensor<float> src_tmp = convert_from_asymmetric(src);
+ SimpleTensor<float> dst_tmp(src.shape(), DataType::F32);
+ for(int i = 0; i < src.num_elements(); ++i)
+ {
+ switch(op)
+ {
+ case ElementWiseUnary::RSQRT:
+ if(src_tmp[i] != 0)
+ {
+ dst_tmp[i] = 1.f / std::sqrt(src_tmp[i]);
+ }
+ else
+ {
+ // rsqrt(0) give 'inf' so set to the maximum in int8: 127
+ dst_tmp[i] = (127.0f - dst.quantization_info().uniform().offset) * dst.quantization_info().uniform().scale ;
+ }
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+ }
+ dst = convert_to_asymmetric<int8_t>(dst_tmp, dst.quantization_info());
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+ return dst;
+}
+template <>
+SimpleTensor<uint8_t> elementwise_unary(const SimpleTensor<uint8_t> &src, SimpleTensor<uint8_t> &dst, ElementWiseUnary op)
+{
+ if(dst.data_type() == DataType::QASYMM8)
+ {
+ SimpleTensor<float> src_tmp = convert_from_asymmetric(src);
+ SimpleTensor<float> dst_tmp(src.shape(), DataType::F32);
+ for(int i = 0; i < src.num_elements(); ++i)
+ {
+ switch(op)
+ {
+ case ElementWiseUnary::RSQRT:
+ if(src_tmp[i] != 0)
+ {
+ dst_tmp[i] = 1.f / std::sqrt(src_tmp[i]);
+ }
+ else
+ {
+ // rsqrt(0) give 'inf' so set to the maximum in uint8: 255
+ dst_tmp[i] = (255.0f - dst.quantization_info().uniform().offset)* dst.quantization_info().uniform().scale;
+ }
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+ }
+ dst = convert_to_asymmetric<uint8_t>(dst_tmp, dst.quantization_info());
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+ return dst;
+}
+
+template SimpleTensor<float> elementwise_unary(const SimpleTensor<float> &src, SimpleTensor<float> &dst, ElementWiseUnary op);
+template SimpleTensor<half> elementwise_unary(const SimpleTensor<half> &src, SimpleTensor<half> &dst, ElementWiseUnary op);
+template SimpleTensor<int32_t> elementwise_unary(const SimpleTensor<int32_t> &src, SimpleTensor<int32_t> &dst, ElementWiseUnary op);
-template SimpleTensor<float> elementwise_unary(const SimpleTensor<float> &src, ElementWiseUnary op);
-template SimpleTensor<half> elementwise_unary(const SimpleTensor<half> &src, ElementWiseUnary op);
-template SimpleTensor<int32_t> elementwise_unary(const SimpleTensor<int32_t> &src, ElementWiseUnary op);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/ElementwiseUnary.h b/tests/validation/reference/ElementwiseUnary.h
index be4a229a5b..ae7a49bce4 100644
--- a/tests/validation/reference/ElementwiseUnary.h
+++ b/tests/validation/reference/ElementwiseUnary.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2019, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -35,7 +35,7 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, ElementWiseUnary op);
+SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, SimpleTensor<T> &dst, ElementWiseUnary op);
} // namespace reference
} // namespace validation
} // namespace test