aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2019-01-03 11:10:25 +0000
committerMichalis Spyrou <michalis.spyrou@arm.com>2019-01-10 16:24:26 +0000
commitaea14c63e2efeda9d5f7492099389d439c65204f (patch)
tree176a6181bbf00e4df078d5da0a17dd44f248958e /tests/validation/fixtures
parentc10bc0b5db5169a6ccea02a1aaefe34f082709e5 (diff)
downloadComputeLibrary-aea14c63e2efeda9d5f7492099389d439c65204f.tar.gz
COMPMID-1764 NEON: Implement ArgMax/ArgMin
Change-Id: Ibe23aa90b36ffd8553d1d1c35fada5d300fab829 Reviewed-on: https://review.mlplatform.org/475 Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r--tests/validation/fixtures/ArgMinMaxFixture.h60
-rw-r--r--tests/validation/fixtures/ReduceMeanFixture.h4
-rw-r--r--tests/validation/fixtures/ReductionOperationFixture.h4
3 files changed, 50 insertions, 18 deletions
diff --git a/tests/validation/fixtures/ArgMinMaxFixture.h b/tests/validation/fixtures/ArgMinMaxFixture.h
index 5f5f85c104..e263b25bf2 100644
--- a/tests/validation/fixtures/ArgMinMaxFixture.h
+++ b/tests/validation/fixtures/ArgMinMaxFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -42,28 +42,38 @@ namespace test
namespace validation
{
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class ArgMinMaxValidationFixture : public framework::Fixture
+class ArgMinMaxValidationBaseFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, DataType data_type, int axis, ReductionOperation op)
+ void setup(TensorShape shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info)
{
- _target = compute_target(shape, data_type, axis, op);
- _reference = compute_reference(shape, data_type, axis, op);
+ _target = compute_target(shape, data_type, axis, op, q_info);
+ _reference = compute_reference(shape, data_type, axis, op, q_info);
}
protected:
template <typename U>
void fill(U &&tensor)
{
- std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
- library->fill(tensor, distribution, 0);
+ if(!is_data_type_quantized(tensor.data_type()))
+ {
+ std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+ library->fill(tensor, distribution, 0);
+ }
+ else
+ {
+ std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+ std::uniform_int_distribution<uint8_t> distribution(bounds.first, bounds.second);
+
+ library->fill(tensor, distribution, 0);
+ }
}
- TensorType compute_target(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op)
+ TensorType compute_target(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(src_shape, data_type, 1);
+ TensorType src = create_tensor<TensorType>(src_shape, data_type, 1, q_info);
TensorType dst;
// Create and configure function
@@ -89,21 +99,43 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op)
+ SimpleTensor<uint32_t> compute_reference(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info)
{
// Create reference
- SimpleTensor<T> src{ src_shape, data_type, 1 };
+ SimpleTensor<T> src{ src_shape, data_type, 1, q_info };
// Fill reference
fill(src);
TensorShape output_shape = src_shape;
output_shape.set(axis, 1);
- return reference::reduction_operation<T>(src, output_shape, axis, op);
+ return reference::reduction_operation<T, uint32_t>(src, output_shape, axis, op);
}
- TensorType _target{};
- SimpleTensor<T> _reference{};
+ TensorType _target{};
+ SimpleTensor<uint32_t> _reference{};
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class ArgMinMaxValidationQuantizedFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo quantization_info)
+ {
+ ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, quantization_info);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class ArgMinMaxValidationFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op)
+ {
+ ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, QuantizationInfo());
+ }
};
} // namespace validation
} // namespace test
diff --git a/tests/validation/fixtures/ReduceMeanFixture.h b/tests/validation/fixtures/ReduceMeanFixture.h
index 769d7f674f..44bb9fca6a 100644
--- a/tests/validation/fixtures/ReduceMeanFixture.h
+++ b/tests/validation/fixtures/ReduceMeanFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -113,7 +113,7 @@ protected:
{
TensorShape output_shape = i == 0 ? src_shape : out.shape();
output_shape.set(axis[i], 1);
- out = reference::reduction_operation<T>(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM);
+ out = reference::reduction_operation<T, T>(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM);
}
if(!keep_dims)
diff --git a/tests/validation/fixtures/ReductionOperationFixture.h b/tests/validation/fixtures/ReductionOperationFixture.h
index 9079b47cbb..d01f41abf0 100644
--- a/tests/validation/fixtures/ReductionOperationFixture.h
+++ b/tests/validation/fixtures/ReductionOperationFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -107,7 +107,7 @@ protected:
// Fill reference
fill(src);
- return reference::reduction_operation<T>(src, dst_shape, axis, op);
+ return reference::reduction_operation<T, T>(src, dst_shape, axis, op);
}
TensorType _target{};