aboutsummaryrefslogtreecommitdiff
path: root/tests/validation
diff options
context:
space:
mode:
authorRamy Elgammal <ramy.elgammal@arm.com>2023-01-30 04:56:47 +0000
committerRamy Elgammal <ramy.elgammal@arm.com>2023-03-17 12:45:26 +0000
commit14d7b535d48620f009efca576cc70fb6ea9ff20d (patch)
tree592ae31798825fba0b71b36adda008fb53381e9f /tests/validation
parent2b6ebfe4270b06b09e45f306e8384950aeca7e4e (diff)
downloadComputeLibrary-14d7b535d48620f009efca576cc70fb6ea9ff20d.tar.gz
Implementation of RSQRT for quantized int8
Resolves: COMPMID-5863 Change-Id: I9ff67face62826c1d335a6b941e8516be39bdac8 Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com> Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/488768 Tested-by: bsgcomp <bsgcomp@arm.com> Comments-Addressed: bsgcomp <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9225 Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation')
-rw-r--r--tests/validation/CL/RsqrtLayer.cpp35
-rw-r--r--tests/validation/fixtures/ElementwiseUnaryFixture.h44
-rw-r--r--tests/validation/reference/ElementwiseUnary.cpp85
-rw-r--r--tests/validation/reference/ElementwiseUnary.h4
4 files changed, 141 insertions, 27 deletions
diff --git a/tests/validation/CL/RsqrtLayer.cpp b/tests/validation/CL/RsqrtLayer.cpp
index 936d853d34..2353bda8d3 100644
--- a/tests/validation/CL/RsqrtLayer.cpp
+++ b/tests/validation/CL/RsqrtLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2019, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -42,8 +42,11 @@ namespace validation
{
namespace
{
-RelativeTolerance<float> tolerance_fp32(0.000001f);
-RelativeTolerance<float> tolerance_fp16(0.001f);
+RelativeTolerance<float> tolerance_fp32(0.000001f);
+RelativeTolerance<float> tolerance_fp16(0.001f);
+constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for unsigned 8-bit asymmetric type */
+constexpr AbsoluteTolerance<int8_t> tolerance_qasymm8_s(1); /**< Tolerance value for comparing reference's output against implementation's output for signed 8-bit asymmetric type */
+
} // namespace
TEST_SUITE(CL)
TEST_SUITE(RsqrtLayer)
@@ -68,6 +71,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(
// *INDENT-ON*
template <typename T>
using CLRsqrtLayerFixture = RsqrtValidationFixture<CLTensor, CLAccessor, CLRsqrtLayer, T>;
+template <typename T>
+using CLRsqrtLayerQuantizedFixture = RsqrtQuantizedValidationFixture<CLTensor, CLAccessor, CLRsqrtLayer, T>;
TEST_SUITE(Float)
TEST_SUITE(FP16)
@@ -102,6 +107,30 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLRsqrtLayerFixture<float>, framework::DatasetM
TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
+TEST_SUITE(Quantized)
+TEST_SUITE(QASYMM8_SIGNED)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLRsqrtLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("SrcQInfo", { QuantizationInfo(0.4044, -128) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(0.0027, -128) })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8_s);
+}
+TEST_SUITE_END() // QASYMM8_SIGNED
+TEST_SUITE(QASYMM8)
+
+FIXTURE_DATA_TEST_CASE(RunSmall, CLRsqrtLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::QASYMM8)),
+ framework::dataset::make("SrcQInfo", { QuantizationInfo(0.4044, 0) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(0.0027, 0) })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8);
+}
+TEST_SUITE_END() // QASYMM8
+TEST_SUITE_END() // Quantized
+
TEST_SUITE_END() // RsqrtLayer
TEST_SUITE_END() // CL
} // namespace validation
diff --git a/tests/validation/fixtures/ElementwiseUnaryFixture.h b/tests/validation/fixtures/ElementwiseUnaryFixture.h
index 7221226fd1..1dc4f03e99 100644
--- a/tests/validation/fixtures/ElementwiseUnaryFixture.h
+++ b/tests/validation/fixtures/ElementwiseUnaryFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -44,11 +44,12 @@ class ElementWiseUnaryValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, DataType input_data_type, bool in_place, ElementWiseUnary op, bool use_dynamic_shape = false)
+ void setup(TensorShape input_shape, DataType input_data_type, bool in_place, ElementWiseUnary op,
+ bool use_dynamic_shape = false, QuantizationInfo qinfo = QuantizationInfo(), QuantizationInfo qinfo_out = QuantizationInfo())
{
_op = op;
- _target = compute_target(input_shape, input_data_type, in_place);
- _reference = compute_reference(input_shape, input_data_type);
+ _target = compute_target(input_shape, input_data_type, in_place, qinfo, qinfo_out);
+ _reference = compute_reference(input_shape, input_data_type, qinfo, qinfo_out);
_use_dynamic_shape = use_dynamic_shape;
}
@@ -69,8 +70,15 @@ protected:
}
case ElementWiseUnary::RSQRT:
{
- FloatDistributionType distribution{ FloatType(1.0f), FloatType(2.0f) };
- library->fill(tensor, distribution, i);
+ if(data_type == DataType::F32 || data_type == DataType::F16)
+ {
+ FloatDistributionType distribution{ FloatType(1.0f), FloatType(2.0f) };
+ library->fill(tensor, distribution, i);
+ }
+ else
+ {
+ library->fill_tensor_uniform(tensor, i);
+ }
break;
}
case ElementWiseUnary::ABS:
@@ -124,12 +132,11 @@ protected:
}
}
- TensorType compute_target(const TensorShape &shape, DataType data_type, bool in_place)
+ TensorType compute_target(const TensorShape &shape, DataType data_type, bool in_place, QuantizationInfo qinfo, QuantizationInfo qinfo_out)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, data_type);
- TensorType dst = create_tensor<TensorType>(shape, data_type);
-
+ TensorType src = create_tensor<TensorType>(shape, data_type, 1, qinfo);
+ TensorType dst = create_tensor<TensorType>(shape, data_type, 1, qinfo_out);
TensorType *actual_dst = in_place ? &src : &dst;
// if _use_dynamic_shape is true, this fixture will test scenario for dynamic shapes.
@@ -176,15 +183,16 @@ protected:
}
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo, QuantizationInfo qinfo_out)
{
// Create reference
- SimpleTensor<T> src{ shape, data_type };
+ SimpleTensor<T> src{ shape, data_type, 1, qinfo };
+ SimpleTensor<T> dst{ shape, data_type, 1, qinfo_out };
// Fill reference
fill(src, 0, data_type);
- return reference::elementwise_unary<T>(src, _op);
+ return reference::elementwise_unary<T>(src, dst, _op);
}
TensorType _target{};
@@ -192,6 +200,16 @@ protected:
ElementWiseUnary _op{};
bool _use_dynamic_shape{ false };
};
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class RsqrtQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo, QuantizationInfo qinfo_out)
+ {
+ ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::RSQRT, false, qinfo, qinfo_out);
+ }
+};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class RsqrtValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
diff --git a/tests/validation/reference/ElementwiseUnary.cpp b/tests/validation/reference/ElementwiseUnary.cpp
index 5333b53c15..d5218d772d 100644
--- a/tests/validation/reference/ElementwiseUnary.cpp
+++ b/tests/validation/reference/ElementwiseUnary.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2020, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -22,7 +22,8 @@
* SOFTWARE.
*/
#include "ElementwiseUnary.h"
-
+#include "tests/validation/Helpers.h"
+#include "utils/TypePrinter.h"
namespace arm_compute
{
namespace test
@@ -32,10 +33,8 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, ElementWiseUnary op)
+SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, SimpleTensor<T> &dst, ElementWiseUnary op)
{
- SimpleTensor<T> dst(src.shape(), src.data_type());
-
for(int i = 0; i < src.num_elements(); ++i)
{
switch(op)
@@ -65,13 +64,81 @@ SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, ElementWiseUnary o
ARM_COMPUTE_ERROR("Not implemented");
}
}
-
return dst;
}
+template <>
+SimpleTensor<int8_t> elementwise_unary(const SimpleTensor<int8_t> &src, SimpleTensor<int8_t> &dst, ElementWiseUnary op)
+{
+ if(dst.data_type() == DataType::QASYMM8_SIGNED)
+ {
+ SimpleTensor<float> src_tmp = convert_from_asymmetric(src);
+ SimpleTensor<float> dst_tmp(src.shape(), DataType::F32);
+ for(int i = 0; i < src.num_elements(); ++i)
+ {
+ switch(op)
+ {
+ case ElementWiseUnary::RSQRT:
+ if(src_tmp[i] != 0)
+ {
+ dst_tmp[i] = 1.f / std::sqrt(src_tmp[i]);
+ }
+ else
+ {
+ // rsqrt(0) give 'inf' so set to the maximum in int8: 127
+ dst_tmp[i] = (127.0f - dst.quantization_info().uniform().offset) * dst.quantization_info().uniform().scale ;
+ }
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+ }
+ dst = convert_to_asymmetric<int8_t>(dst_tmp, dst.quantization_info());
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+ return dst;
+}
+template <>
+SimpleTensor<uint8_t> elementwise_unary(const SimpleTensor<uint8_t> &src, SimpleTensor<uint8_t> &dst, ElementWiseUnary op)
+{
+ if(dst.data_type() == DataType::QASYMM8)
+ {
+ SimpleTensor<float> src_tmp = convert_from_asymmetric(src);
+ SimpleTensor<float> dst_tmp(src.shape(), DataType::F32);
+ for(int i = 0; i < src.num_elements(); ++i)
+ {
+ switch(op)
+ {
+ case ElementWiseUnary::RSQRT:
+ if(src_tmp[i] != 0)
+ {
+ dst_tmp[i] = 1.f / std::sqrt(src_tmp[i]);
+ }
+ else
+ {
+ // rsqrt(0) give 'inf' so set to the maximum in uint8: 255
+ dst_tmp[i] = (255.0f - dst.quantization_info().uniform().offset)* dst.quantization_info().uniform().scale;
+ }
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+ }
+ dst = convert_to_asymmetric<uint8_t>(dst_tmp, dst.quantization_info());
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+ return dst;
+}
+
+template SimpleTensor<float> elementwise_unary(const SimpleTensor<float> &src, SimpleTensor<float> &dst, ElementWiseUnary op);
+template SimpleTensor<half> elementwise_unary(const SimpleTensor<half> &src, SimpleTensor<half> &dst, ElementWiseUnary op);
+template SimpleTensor<int32_t> elementwise_unary(const SimpleTensor<int32_t> &src, SimpleTensor<int32_t> &dst, ElementWiseUnary op);
-template SimpleTensor<float> elementwise_unary(const SimpleTensor<float> &src, ElementWiseUnary op);
-template SimpleTensor<half> elementwise_unary(const SimpleTensor<half> &src, ElementWiseUnary op);
-template SimpleTensor<int32_t> elementwise_unary(const SimpleTensor<int32_t> &src, ElementWiseUnary op);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/ElementwiseUnary.h b/tests/validation/reference/ElementwiseUnary.h
index be4a229a5b..ae7a49bce4 100644
--- a/tests/validation/reference/ElementwiseUnary.h
+++ b/tests/validation/reference/ElementwiseUnary.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2019, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -35,7 +35,7 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, ElementWiseUnary op);
+SimpleTensor<T> elementwise_unary(const SimpleTensor<T> &src, SimpleTensor<T> &dst, ElementWiseUnary op);
} // namespace reference
} // namespace validation
} // namespace test