diff options
Diffstat (limited to 'tests/validation/fixtures/ScaleFixture.h')
-rw-r--r-- | tests/validation/fixtures/ScaleFixture.h | 80 |
1 files changed, 48 insertions, 32 deletions
diff --git a/tests/validation/fixtures/ScaleFixture.h b/tests/validation/fixtures/ScaleFixture.h index b719a22fdf..86d89d71f7 100644 --- a/tests/validation/fixtures/ScaleFixture.h +++ b/tests/validation/fixtures/ScaleFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,15 +21,10 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_TEST_SCALE_FIXTURE -#define ARM_COMPUTE_TEST_SCALE_FIXTURE - -#include "arm_compute/core/TensorShape.h" -#include "arm_compute/core/Types.h" -#include "tests/AssetsLibrary.h" -#include "tests/Globals.h" -#include "tests/IAccessor.h" -#include "tests/framework/Asserts.h" +#ifndef ACL_TESTS_VALIDATION_FIXTURES_SCALEFIXTURE_H +#define ACL_TESTS_VALIDATION_FIXTURES_SCALEFIXTURE_H + +#include "tests/framework/Asserts.h" // Required for ARM_COMPUTE_ASSERT #include "tests/framework/Fixture.h" #include "tests/validation/reference/Permute.h" #include "tests/validation/reference/Scale.h" @@ -44,23 +39,23 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class ScaleValidationGenericFixture : public framework::Fixture { public: - template <typename...> void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy, - bool align_corners, bool mixed_layout) + bool align_corners, bool mixed_layout, QuantizationInfo output_quantization_info) { - _shape = shape; - _policy = policy; - _border_mode = border_mode; - _sampling_policy = sampling_policy; - _data_type = data_type; - _quantization_info = quantization_info; - _align_corners = align_corners; - _mixed_layout = mixed_layout; + _shape = shape; + _policy = policy; + _border_mode = border_mode; + _sampling_policy = sampling_policy; + _data_type = data_type; + _input_quantization_info = quantization_info; + _output_quantization_info = output_quantization_info; + _align_corners = align_corners; + _mixed_layout = mixed_layout; generate_scale(shape); - std::mt19937 generator(library->seed()); - std::uniform_int_distribution<uint8_t> distribution_u8(0, 255); + std::mt19937 generator(library->seed()); + std::uniform_int_distribution<uint32_t> distribution_u8(0, 255); _constant_border_value = static_cast<T>(distribution_u8(generator)); _target = compute_target(shape, data_layout); @@ -144,7 +139,7 @@ protected: } // Create tensors - TensorType src = create_tensor<TensorType>(shape, _data_type, 1, _quantization_info, data_layout); + TensorType src = create_tensor<TensorType>(shape, _data_type, 1, _input_quantization_info, data_layout); const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); @@ -152,7 +147,7 @@ protected: TensorShape shape_scaled(shape); shape_scaled.set(idx_width, shape[idx_width] * _scale_x, /* apply_dim_correction = */ false); shape_scaled.set(idx_height, shape[idx_height] * _scale_y, /* apply_dim_correction = */ false); - TensorType dst = create_tensor<TensorType>(shape_scaled, _data_type, 1, _quantization_info, data_layout); + TensorType dst = create_tensor<TensorType>(shape_scaled, _data_type, 1, _output_quantization_info, data_layout); // Create and configure function FunctionType scale; @@ -188,12 +183,12 @@ protected: SimpleTensor<T> compute_reference(const TensorShape &shape) { // Create reference - SimpleTensor<T> src{ shape, _data_type, 1, _quantization_info }; + SimpleTensor<T> src{ shape, _data_type, 1, _input_quantization_info }; // Fill reference fill(src); - return reference::scale<T>(src, _scale_x, _scale_y, _policy, _border_mode, _constant_border_value, _sampling_policy, /* ceil_policy_scale */ false, _align_corners); + return reference::scale<T>(src, _scale_x, _scale_y, _policy, _border_mode, _constant_border_value, _sampling_policy, /* ceil_policy_scale */ false, _align_corners, _output_quantization_info); } TensorType _target{}; @@ -204,7 +199,8 @@ protected: T _constant_border_value{}; SamplingPolicy _sampling_policy{}; DataType _data_type{}; - QuantizationInfo _quantization_info{}; + QuantizationInfo _input_quantization_info{}; + QuantizationInfo _output_quantization_info{}; bool _align_corners{ false }; bool _mixed_layout{ false }; float _scale_x{ 1.f }; @@ -215,7 +211,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class ScaleValidationQuantizedFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy, bool align_corners) { @@ -227,14 +222,34 @@ public: border_mode, sampling_policy, align_corners, - mixed_layout); + mixed_layout, + quantization_info); + } +}; +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false> +class ScaleValidationDifferentOutputQuantizedFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + void setup(TensorShape shape, DataType data_type, QuantizationInfo input_quantization_info, QuantizationInfo output_quantization_info, DataLayout data_layout, InterpolationPolicy policy, + BorderMode border_mode, SamplingPolicy sampling_policy, + bool align_corners) + { + ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, + data_type, + input_quantization_info, + data_layout, + policy, + border_mode, + sampling_policy, + align_corners, + mixed_layout, + output_quantization_info); } }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false> class ScaleValidationFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(TensorShape shape, DataType data_type, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy, bool align_corners) { ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, @@ -245,10 +260,11 @@ public: border_mode, sampling_policy, align_corners, - mixed_layout); + mixed_layout, + QuantizationInfo()); } }; } // namespace validation } // namespace test } // namespace arm_compute -#endif /* ARM_COMPUTE_TEST_SCALE_FIXTURE */ +#endif // ACL_TESTS_VALIDATION_FIXTURES_SCALEFIXTURE_H |