From aea14c63e2efeda9d5f7492099389d439c65204f Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Thu, 3 Jan 2019 11:10:25 +0000 Subject: COMPMID-1764 NEON: Implement ArgMax/ArgMin Change-Id: Ibe23aa90b36ffd8553d1d1c35fada5d300fab829 Reviewed-on: https://review.mlplatform.org/475 Reviewed-by: Isabella Gottardi Tested-by: Arm Jenkins Reviewed-by: Giuseppe Rossini --- tests/validation/fixtures/ArgMinMaxFixture.h | 60 +++++++++++++++++----- tests/validation/fixtures/ReduceMeanFixture.h | 4 +- .../fixtures/ReductionOperationFixture.h | 4 +- 3 files changed, 50 insertions(+), 18 deletions(-) (limited to 'tests/validation/fixtures') diff --git a/tests/validation/fixtures/ArgMinMaxFixture.h b/tests/validation/fixtures/ArgMinMaxFixture.h index 5f5f85c104..e263b25bf2 100644 --- a/tests/validation/fixtures/ArgMinMaxFixture.h +++ b/tests/validation/fixtures/ArgMinMaxFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -42,28 +42,38 @@ namespace test namespace validation { template -class ArgMinMaxValidationFixture : public framework::Fixture +class ArgMinMaxValidationBaseFixture : public framework::Fixture { public: template - void setup(TensorShape shape, DataType data_type, int axis, ReductionOperation op) + void setup(TensorShape shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info) { - _target = compute_target(shape, data_type, axis, op); - _reference = compute_reference(shape, data_type, axis, op); + _target = compute_target(shape, data_type, axis, op, q_info); + _reference = compute_reference(shape, data_type, axis, op, q_info); } protected: template void fill(U &&tensor) { - std::uniform_real_distribution<> distribution(-1.0f, 1.0f); - library->fill(tensor, distribution, 0); + if(!is_data_type_quantized(tensor.data_type())) + { + std::uniform_real_distribution<> distribution(-1.0f, 1.0f); + library->fill(tensor, distribution, 0); + } + else + { + std::pair bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); + std::uniform_int_distribution distribution(bounds.first, bounds.second); + + library->fill(tensor, distribution, 0); + } } - TensorType compute_target(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op) + TensorType compute_target(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info) { // Create tensors - TensorType src = create_tensor(src_shape, data_type, 1); + TensorType src = create_tensor(src_shape, data_type, 1, q_info); TensorType dst; // Create and configure function @@ -89,21 +99,43 @@ protected: return dst; } - SimpleTensor compute_reference(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op) + SimpleTensor compute_reference(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info) { // Create reference - SimpleTensor src{ src_shape, data_type, 1 }; + SimpleTensor src{ src_shape, data_type, 1, q_info }; // Fill reference fill(src); TensorShape output_shape = src_shape; output_shape.set(axis, 1); - return reference::reduction_operation(src, output_shape, axis, op); + return reference::reduction_operation(src, output_shape, axis, op); } - TensorType _target{}; - SimpleTensor _reference{}; + TensorType _target{}; + SimpleTensor _reference{}; +}; + +template +class ArgMinMaxValidationQuantizedFixture : public ArgMinMaxValidationBaseFixture +{ +public: + template + void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo quantization_info) + { + ArgMinMaxValidationBaseFixture::setup(shape, data_type, axis, op, quantization_info); + } +}; + +template +class ArgMinMaxValidationFixture : public ArgMinMaxValidationBaseFixture +{ +public: + template + void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op) + { + ArgMinMaxValidationBaseFixture::setup(shape, data_type, axis, op, QuantizationInfo()); + } }; } // namespace validation } // namespace test diff --git a/tests/validation/fixtures/ReduceMeanFixture.h b/tests/validation/fixtures/ReduceMeanFixture.h index 769d7f674f..44bb9fca6a 100644 --- a/tests/validation/fixtures/ReduceMeanFixture.h +++ b/tests/validation/fixtures/ReduceMeanFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -113,7 +113,7 @@ protected: { TensorShape output_shape = i == 0 ? src_shape : out.shape(); output_shape.set(axis[i], 1); - out = reference::reduction_operation(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM); + out = reference::reduction_operation(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM); } if(!keep_dims) diff --git a/tests/validation/fixtures/ReductionOperationFixture.h b/tests/validation/fixtures/ReductionOperationFixture.h index 9079b47cbb..d01f41abf0 100644 --- a/tests/validation/fixtures/ReductionOperationFixture.h +++ b/tests/validation/fixtures/ReductionOperationFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -107,7 +107,7 @@ protected: // Fill reference fill(src); - return reference::reduction_operation(src, dst_shape, axis, op); + return reference::reduction_operation(src, dst_shape, axis, op); } TensorType _target{}; -- cgit v1.2.1