diff options
Diffstat (limited to 'tests/validation/fixtures/ScaleFixture.h')
-rw-r--r-- | tests/validation/fixtures/ScaleFixture.h | 148 |
1 files changed, 97 insertions, 51 deletions
diff --git a/tests/validation/fixtures/ScaleFixture.h b/tests/validation/fixtures/ScaleFixture.h index cf3c5c818f..86d89d71f7 100644 --- a/tests/validation/fixtures/ScaleFixture.h +++ b/tests/validation/fixtures/ScaleFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 ARM Limited. + * Copyright (c) 2017-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,15 +21,10 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_TEST_SCALE_FIXTURE -#define ARM_COMPUTE_TEST_SCALE_FIXTURE - -#include "arm_compute/core/TensorShape.h" -#include "arm_compute/core/Types.h" -#include "tests/AssetsLibrary.h" -#include "tests/Globals.h" -#include "tests/IAccessor.h" -#include "tests/framework/Asserts.h" +#ifndef ACL_TESTS_VALIDATION_FIXTURES_SCALEFIXTURE_H +#define ACL_TESTS_VALIDATION_FIXTURES_SCALEFIXTURE_H + +#include "tests/framework/Asserts.h" // Required for ARM_COMPUTE_ASSERT #include "tests/framework/Fixture.h" #include "tests/validation/reference/Permute.h" #include "tests/validation/reference/Scale.h" @@ -44,22 +39,23 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class ScaleValidationGenericFixture : public framework::Fixture { public: - template <typename...> void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy, - bool align_corners) + bool align_corners, bool mixed_layout, QuantizationInfo output_quantization_info) { - _shape = shape; - _policy = policy; - _border_mode = border_mode; - _sampling_policy = sampling_policy; - _data_type = data_type; - _quantization_info = quantization_info; - _align_corners = align_corners && _policy == InterpolationPolicy::BILINEAR && _sampling_policy == SamplingPolicy::TOP_LEFT; + _shape = shape; + _policy = policy; + _border_mode = border_mode; + _sampling_policy = sampling_policy; + _data_type = data_type; + _input_quantization_info = quantization_info; + _output_quantization_info = output_quantization_info; + _align_corners = align_corners; + _mixed_layout = mixed_layout; generate_scale(shape); - std::mt19937 generator(library->seed()); - std::uniform_int_distribution<uint8_t> distribution_u8(0, 255); + std::mt19937 generator(library->seed()); + std::uniform_int_distribution<uint32_t> distribution_u8(0, 255); _constant_border_value = static_cast<T>(distribution_u8(generator)); _target = compute_target(shape, data_layout); @@ -67,6 +63,21 @@ public: } protected: + void mix_layout(FunctionType &layer, TensorType &src, TensorType &dst) + { + const DataLayout data_layout = src.info()->data_layout(); + // Test Multi DataLayout graph cases, when the data layout changes after configure + src.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); + dst.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); + + // Compute Convolution function + layer.run(); + + // Reinstating original data layout for the test suite to properly check the values + src.info()->set_data_layout(data_layout); + dst.info()->set_data_layout(data_layout); + } + void generate_scale(const TensorShape &shape) { static constexpr float _min_scale{ 0.25f }; @@ -74,9 +85,8 @@ protected: constexpr float max_width{ 8192.0f }; constexpr float max_height{ 6384.0f }; - - const float min_width = _align_corners ? 2.f : 1.f; - const float min_height = _align_corners ? 2.f : 1.f; + const float min_width{ 1.f }; + const float min_height{ 1.f }; std::mt19937 generator(library->seed()); std::uniform_real_distribution<float> distribution_float(_min_scale, _max_scale); @@ -99,9 +109,15 @@ protected: template <typename U> void fill(U &&tensor) { - if(is_data_type_float(_data_type)) + if(tensor.data_type() == DataType::F32) { - library->fill_tensor_uniform(tensor, 0); + std::uniform_real_distribution<float> distribution(-5.0f, 5.0f); + library->fill(tensor, distribution, 0); + } + else if(tensor.data_type() == DataType::F16) + { + arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -5.0f, 5.0f }; + library->fill(tensor, distribution, 0); } else if(is_data_type_quantized(tensor.data_type())) { @@ -110,9 +126,7 @@ protected: } else { - // Restrict range for float to avoid any floating point issues - std::uniform_real_distribution<> distribution(-5.0f, 5.0f); - library->fill(tensor, distribution, 0); + library->fill_tensor_uniform(tensor, 0); } } @@ -125,48 +139,56 @@ protected: } // Create tensors - TensorType src = create_tensor<TensorType>(shape, _data_type, 1, _quantization_info, data_layout); + TensorType src = create_tensor<TensorType>(shape, _data_type, 1, _input_quantization_info, data_layout); const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); TensorShape shape_scaled(shape); - shape_scaled.set(idx_width, shape[idx_width] * _scale_x); - shape_scaled.set(idx_height, shape[idx_height] * _scale_y); - TensorType dst = create_tensor<TensorType>(shape_scaled, _data_type, 1, _quantization_info, data_layout); + shape_scaled.set(idx_width, shape[idx_width] * _scale_x, /* apply_dim_correction = */ false); + shape_scaled.set(idx_height, shape[idx_height] * _scale_y, /* apply_dim_correction = */ false); + TensorType dst = create_tensor<TensorType>(shape_scaled, _data_type, 1, _output_quantization_info, data_layout); // Create and configure function FunctionType scale; - scale.configure(&src, &dst, _policy, _border_mode, _constant_border_value, _sampling_policy, /* use_padding */ true, _align_corners); + scale.configure(&src, &dst, ScaleKernelInfo{ _policy, _border_mode, _constant_border_value, _sampling_policy, /* use_padding */ false, _align_corners }); + + ARM_COMPUTE_ASSERT(src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); - ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + add_padding_x({ &src, &dst }, data_layout); // Allocate tensors src.allocator()->allocate(); dst.allocator()->allocate(); - ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); // Fill tensors fill(AccessorType(src)); - // Compute function - scale.run(); - + if(_mixed_layout) + { + mix_layout(scale, src, dst); + } + else + { + // Compute function + scale.run(); + } return dst; } SimpleTensor<T> compute_reference(const TensorShape &shape) { // Create reference - SimpleTensor<T> src{ shape, _data_type, 1, _quantization_info }; + SimpleTensor<T> src{ shape, _data_type, 1, _input_quantization_info }; // Fill reference fill(src); - return reference::scale<T>(src, _scale_x, _scale_y, _policy, _border_mode, _constant_border_value, _sampling_policy, /* ceil_policy_scale */ false, _align_corners); + return reference::scale<T>(src, _scale_x, _scale_y, _policy, _border_mode, _constant_border_value, _sampling_policy, /* ceil_policy_scale */ false, _align_corners, _output_quantization_info); } TensorType _target{}; @@ -177,17 +199,18 @@ protected: T _constant_border_value{}; SamplingPolicy _sampling_policy{}; DataType _data_type{}; - QuantizationInfo _quantization_info{}; + QuantizationInfo _input_quantization_info{}; + QuantizationInfo _output_quantization_info{}; bool _align_corners{ false }; + bool _mixed_layout{ false }; float _scale_x{ 1.f }; float _scale_y{ 1.f }; }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false> class ScaleValidationQuantizedFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy, bool align_corners) { @@ -198,14 +221,35 @@ public: policy, border_mode, sampling_policy, - align_corners); + align_corners, + mixed_layout, + quantization_info); } }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false> +class ScaleValidationDifferentOutputQuantizedFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + void setup(TensorShape shape, DataType data_type, QuantizationInfo input_quantization_info, QuantizationInfo output_quantization_info, DataLayout data_layout, InterpolationPolicy policy, + BorderMode border_mode, SamplingPolicy sampling_policy, + bool align_corners) + { + ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, + data_type, + input_quantization_info, + data_layout, + policy, + border_mode, + sampling_policy, + align_corners, + mixed_layout, + output_quantization_info); + } +}; +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false> class ScaleValidationFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(TensorShape shape, DataType data_type, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy, bool align_corners) { ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, @@ -215,10 +259,12 @@ public: policy, border_mode, sampling_policy, - align_corners); + align_corners, + mixed_layout, + QuantizationInfo()); } }; } // namespace validation } // namespace test } // namespace arm_compute -#endif /* ARM_COMPUTE_TEST_SCALE_FIXTURE */ +#endif // ACL_TESTS_VALIDATION_FIXTURES_SCALEFIXTURE_H |