diff options
Diffstat (limited to 'tests/validation/fixtures/DequantizationLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/DequantizationLayerFixture.h | 87 |
1 files changed, 13 insertions, 74 deletions
diff --git a/tests/validation/fixtures/DequantizationLayerFixture.h b/tests/validation/fixtures/DequantizationLayerFixture.h index 0bf3522cd6..2e3712dff2 100644 --- a/tests/validation/fixtures/DequantizationLayerFixture.h +++ b/tests/validation/fixtures/DequantizationLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -47,10 +47,10 @@ class DequantizationValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, DataType data_type) + void setup(TensorShape shape, DataType data_type, QuantizationInfo qinfo) { - _target = compute_target(shape, data_type); - _reference = compute_reference(shape, data_type); + _target = compute_target(shape, data_type, qinfo); + _reference = compute_reference(shape, data_type, qinfo); } protected: @@ -60,80 +60,28 @@ protected: library->fill_tensor_uniform(tensor, 0); } - template <typename U> - void fill_min_max(U &&tensor) - { - std::mt19937 gen(library->seed()); - std::uniform_real_distribution<float> distribution(-1.0f, 1.0f); - - Window window; - - window.set(0, Window::Dimension(0, tensor.shape()[0], 2)); - - for(unsigned int d = 1; d < tensor.shape().num_dimensions(); ++d) - { - window.set(d, Window::Dimension(0, tensor.shape()[d], 1)); - } - - execute_window_loop(window, [&](const Coordinates & id) - { - const float n1 = distribution(gen); - const float n2 = distribution(gen); - - float min = 0.0f; - float max = 0.0f; - - if(n1 < n2) - { - min = n1; - max = n2; - } - else - { - min = n2; - max = n1; - } - - auto out_ptr = reinterpret_cast<float *>(tensor(id)); - out_ptr[0] = min; - out_ptr[1] = max; - }); - } - - TensorType compute_target(const TensorShape &shape, DataType data_type) + TensorType compute_target(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo) { - TensorShape shape_min_max = shape; - shape_min_max.set(Window::DimX, 2); - - // Remove Y and Z dimensions and keep the batches - shape_min_max.remove_dimension(1); - shape_min_max.remove_dimension(1); - // Create tensors - TensorType src = create_tensor<TensorType>(shape, data_type); - TensorType dst = create_tensor<TensorType>(shape, DataType::F32); - TensorType min_max = create_tensor<TensorType>(shape_min_max, DataType::F32); + TensorType src = create_tensor<TensorType>(shape, DataType::QASYMM8, 1, qinfo); + TensorType dst = create_tensor<TensorType>(shape, data_type); // Create and configure function FunctionType dequantization_layer; - dequantization_layer.configure(&src, &dst, &min_max); + dequantization_layer.configure(&src, &dst); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(min_max.info()->is_resizable(), framework::LogLevel::ERRORS); // Allocate tensors src.allocator()->allocate(); dst.allocator()->allocate(); - min_max.allocator()->allocate(); ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!min_max.info()->is_resizable(), framework::LogLevel::ERRORS); // Fill tensors fill(AccessorType(src)); - fill_min_max(AccessorType(min_max)); // Compute function dequantization_layer.run(); @@ -141,28 +89,19 @@ protected: return dst; } - SimpleTensor<float> compute_reference(const TensorShape &shape, DataType data_type) + SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo) { - TensorShape shape_min_max = shape; - shape_min_max.set(Window::DimX, 2); - - // Remove Y and Z dimensions and keep the batches - shape_min_max.remove_dimension(1); - shape_min_max.remove_dimension(1); - // Create reference - SimpleTensor<T> src{ shape, data_type }; - SimpleTensor<float> min_max{ shape_min_max, data_type }; + SimpleTensor<uint8_t> src{ shape, DataType::QASYMM8, 1, qinfo }; // Fill reference fill(src); - fill_min_max(min_max); - return reference::dequantization_layer<T>(src, min_max); + return reference::dequantization_layer<T>(src); } - TensorType _target{}; - SimpleTensor<float> _reference{}; + TensorType _target{}; + SimpleTensor<T> _reference{}; }; } // namespace validation } // namespace test |