aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ReverseFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ReverseFixture.h')
-rw-r--r--tests/validation/fixtures/ReverseFixture.h52
1 files changed, 31 insertions, 21 deletions
diff --git a/tests/validation/fixtures/ReverseFixture.h b/tests/validation/fixtures/ReverseFixture.h
index 9d047a0067..856bff7b12 100644
--- a/tests/validation/fixtures/ReverseFixture.h
+++ b/tests/validation/fixtures/ReverseFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_REVERSE_FIXTURE
-#define ARM_COMPUTE_TEST_REVERSE_FIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_REVERSEFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_REVERSEFIXTURE_H
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorShape.h"
@@ -45,11 +45,11 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ReverseValidationFixture : public framework::Fixture
{
public:
- template <typename...>
- void setup(TensorShape shape, TensorShape axis_shape, DataType data_type)
+ void setup(TensorShape shape, TensorShape axis_shape, DataType data_type, bool use_negative_axis = false, bool use_inverted_axis = false)
{
- _target = compute_target(shape, axis_shape, data_type);
- _reference = compute_reference(shape, axis_shape, data_type);
+ _num_dims = shape.num_dimensions();
+ _target = compute_target(shape, axis_shape, data_type, use_negative_axis, use_inverted_axis);
+ _reference = compute_reference(shape, axis_shape, data_type, use_negative_axis, use_inverted_axis);
}
protected:
@@ -58,16 +58,25 @@ protected:
{
library->fill_tensor_uniform(tensor, 0);
}
- std::vector<int> generate_random_axis()
+ std::vector<int32_t> generate_random_axis(bool use_negative = false)
{
- std::vector<int> axis_v = { 0, 1, 2, 3 };
- std::mt19937 g(0);
+ std::vector<int32_t> axis_v;
+ if(use_negative)
+ {
+ axis_v = { -1, -2, -3, -4 };
+ }
+ else
+ {
+ axis_v = { 0, 1, 2, 3 };
+ }
+ axis_v = std::vector<int32_t>(axis_v.begin(), axis_v.begin() + _num_dims);
+ std::mt19937 g(library->seed());
std::shuffle(axis_v.begin(), axis_v.end(), g);
return axis_v;
}
- TensorType compute_target(const TensorShape &shape, const TensorShape &axis_shape, DataType data_type)
+ TensorType compute_target(const TensorShape &shape, const TensorShape &axis_shape, DataType data_type, bool use_negative_axis, bool use_inverted_axis = false)
{
// Create tensors
TensorType src = create_tensor<TensorType>(shape, data_type, 1);
@@ -76,7 +85,7 @@ protected:
// Create and configure function
FunctionType reverse_func;
- reverse_func.configure(&src, &dst, &axis);
+ reverse_func.configure(&src, &dst, &axis, use_inverted_axis);
ARM_COMPUTE_ASSERT(src.info()->is_resizable());
ARM_COMPUTE_ASSERT(axis.info()->is_resizable());
@@ -95,8 +104,8 @@ protected:
fill(AccessorType(src));
{
auto axis_data = AccessorType(axis);
- auto axis_v = generate_random_axis();
- std::copy(axis_v.begin(), axis_v.begin() + axis_shape.x(), static_cast<int32_t *>(axis_data.data()));
+ auto axis_v = generate_random_axis(use_negative_axis);
+ std::copy(axis_v.begin(), axis_v.begin() + axis_shape.total_size(), static_cast<int32_t *>(axis_data.data()));
}
// Compute function
@@ -105,24 +114,25 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, const TensorShape &axis_shape, DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, const TensorShape &axis_shape, DataType data_type, bool use_negative_axis, bool use_inverted_axis = false)
{
// Create reference
- SimpleTensor<T> src{ shape, data_type };
- SimpleTensor<uint32_t> axis{ axis_shape, DataType::U32 };
+ SimpleTensor<T> src{ shape, data_type };
+ SimpleTensor<int32_t> axis{ axis_shape, DataType::S32 };
// Fill reference
fill(src);
- auto axis_v = generate_random_axis();
- std::copy(axis_v.begin(), axis_v.begin() + axis_shape.x(), axis.data());
+ auto axis_v = generate_random_axis(use_negative_axis);
+ std::copy(axis_v.begin(), axis_v.begin() + axis_shape.total_size(), axis.data());
- return reference::reverse<T>(src, axis);
+ return reference::reverse<T>(src, axis, use_inverted_axis);
}
TensorType _target{};
SimpleTensor<T> _reference{};
+ unsigned int _num_dims{};
};
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_REVERSE_FIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_REVERSEFIXTURE_H