aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorAdnan AlSinan <adnan.alsinan@arm.com>2023-09-18 14:49:45 +0100
committerAdnan AlSinan <adnan.alsinan@arm.com>2023-09-27 15:13:29 +0000
commitbdcb4c148ee2fdeaaddf4cf1e57bbb0de02bb894 (patch)
treeee9743ddfe42b800bbc54dc3c273c188cb779017 /tests
parent729099c5d134c2c34459a2bdbd5453ad4ca68cac (diff)
downloadComputeLibrary-bdcb4c148ee2fdeaaddf4cf1e57bbb0de02bb894.tar.gz
Implement tflite compliant reverse for CPU
- Add support for negative axis values. - Add option to use opposite ACL convention for dimension addressing. - Add validation tests for the mentioned additions. Resolves COMPMID-6497 Change-Id: I9174b201c3adc070766cc6cffcbe4ec1fe5ec1c3 Signed-off-by: Adnan AlSinan <adnan.alsinan@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10335 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: SiCong Li <sicong.li@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/validation/CL/Reverse.cpp47
-rw-r--r--tests/validation/NEON/Reverse.cpp50
-rw-r--r--tests/validation/fixtures/ReverseFixture.h74
-rw-r--r--tests/validation/reference/DFT.cpp8
-rw-r--r--tests/validation/reference/Reverse.cpp35
-rw-r--r--tests/validation/reference/Reverse.h10
6 files changed, 170 insertions, 54 deletions
diff --git a/tests/validation/CL/Reverse.cpp b/tests/validation/CL/Reverse.cpp
index 11df0e7803..ff46ba64ad 100644
--- a/tests/validation/CL/Reverse.cpp
+++ b/tests/validation/CL/Reverse.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2020, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,6 +41,7 @@ namespace test
{
namespace validation
{
+using framework::dataset::make;
namespace
{
auto run_small_dataset = combine(datasets::SmallShapes(), datasets::Tiny1DShapes());
@@ -53,28 +54,28 @@ TEST_SUITE(Reverse)
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
- framework::dataset::make("InputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8), // Invalid axis datatype
+ make("InputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8), // Invalid axis datatype
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid axis shape
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid axis length (> 4)
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Mismatching shapes
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(2U), 1, DataType::U8),
}),
- framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8),
+ make("OutputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(2U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(2U), 1, DataType::U8),
})),
- framework::dataset::make("AxisInfo",{ TensorInfo(TensorShape(3U), 1, DataType::U8),
+ make("AxisInfo",{ TensorInfo(TensorShape(3U), 1, DataType::U8),
TensorInfo(TensorShape(2U, 10U), 1, DataType::U32),
TensorInfo(TensorShape(8U), 1, DataType::U32),
TensorInfo(TensorShape(2U), 1, DataType::U32),
TensorInfo(TensorShape(2U), 1, DataType::U32),
TensorInfo(TensorShape(2U), 1, DataType::U32),
})),
- framework::dataset::make("Expected", { false, false, false, false, true, true})),
+ make("Expected", { false, false, false, false, true, true})),
src_info, dst_info, axis_info, expected)
{
Status s = CLReverse::validate(&src_info.clone()->set_is_resizable(false),
@@ -93,7 +94,11 @@ TEST_SUITE(F16)
FIXTURE_DATA_TEST_CASE(RunSmall,
CLReverseFixture<half>,
framework::DatasetMode::PRECOMMIT,
- combine(run_small_dataset, framework::dataset::make("DataType", DataType::F16)))
+ combine(
+ run_small_dataset,
+ make("DataType", DataType::F16),
+ make("use_negative_axis", { false }),
+ make("use_inverted_axis", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -102,7 +107,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
FIXTURE_DATA_TEST_CASE(RunLarge,
CLReverseFixture<half>,
framework::DatasetMode::NIGHTLY,
- combine(run_large_dataset, framework::dataset::make("DataType", DataType::F16)))
+ combine(
+ run_large_dataset,
+ make("DataType", DataType::F16),
+ make("use_negative_axis", { false }),
+ make("use_inverted_axis", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -113,7 +122,11 @@ TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall,
CLReverseFixture<float>,
framework::DatasetMode::PRECOMMIT,
- combine(run_small_dataset, framework::dataset::make("DataType", DataType::F32)))
+ combine(
+ run_small_dataset,
+ make("DataType", DataType::F32),
+ make("use_negative_axis", { false }),
+ make("use_inverted_axis", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -122,7 +135,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
FIXTURE_DATA_TEST_CASE(RunLarge,
CLReverseFixture<float>,
framework::DatasetMode::NIGHTLY,
- combine(run_large_dataset, framework::dataset::make("DataType", DataType::F32)))
+ combine(
+ run_large_dataset,
+ make("DataType", DataType::F32),
+ make("use_negative_axis", { false }),
+ make("use_inverted_axis", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -135,7 +152,11 @@ TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall,
CLReverseFixture<uint8_t>,
framework::DatasetMode::PRECOMMIT,
- combine(run_small_dataset, framework::dataset::make("DataType", DataType::QASYMM8)))
+ combine(
+ run_small_dataset,
+ make("DataType", DataType::QASYMM8),
+ make("use_negative_axis", { false }),
+ make("use_inverted_axis", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -144,7 +165,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
FIXTURE_DATA_TEST_CASE(RunLarge,
CLReverseFixture<uint8_t>,
framework::DatasetMode::NIGHTLY,
- combine(run_large_dataset, framework::dataset::make("DataType", DataType::QASYMM8)))
+ combine(
+ run_large_dataset,
+ make("DataType", DataType::QASYMM8),
+ make("use_negative_axis", { false }),
+ make("use_inverted_axis", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
diff --git a/tests/validation/NEON/Reverse.cpp b/tests/validation/NEON/Reverse.cpp
index 3dc3eeee80..eb0c995ad9 100644
--- a/tests/validation/NEON/Reverse.cpp
+++ b/tests/validation/NEON/Reverse.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -43,6 +43,7 @@ namespace validation
{
namespace
{
+using framework::dataset::make;
auto run_small_dataset = combine(datasets::SmallShapes(), datasets::Tiny1DShapes());
auto run_large_dataset = combine(datasets::LargeShapes(), datasets::Tiny1DShapes());
@@ -53,28 +54,31 @@ TEST_SUITE(Reverse)
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
- framework::dataset::make("InputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8), // Invalid axis datatype
+ make("InputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8), // Invalid axis datatype
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid axis shape
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid axis length (> 4)
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Mismatching shapes
+ TensorInfo(TensorShape(32U, 13U, 17U, 3U, 2U), 1, DataType::U8), // Unsupported source dimensions (>4)
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(2U), 1, DataType::U8),
}),
- framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8),
+ make("OutputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(2U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(32U, 13U, 17U, 3U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(2U), 1, DataType::U8),
})),
- framework::dataset::make("AxisInfo", { TensorInfo(TensorShape(3U), 1, DataType::U8),
+ make("AxisInfo", { TensorInfo(TensorShape(3U), 1, DataType::U8),
TensorInfo(TensorShape(2U, 10U), 1, DataType::U32),
TensorInfo(TensorShape(8U), 1, DataType::U32),
TensorInfo(TensorShape(2U), 1, DataType::U32),
TensorInfo(TensorShape(2U), 1, DataType::U32),
TensorInfo(TensorShape(2U), 1, DataType::U32),
+ TensorInfo(TensorShape(2U), 1, DataType::U32),
})),
- framework::dataset::make("Expected", { false, false, false, false, true, true})),
+ make("Expected", { false, false, false, false, false, true, true})),
src_info, dst_info, axis_info, expected)
{
Status s = NEReverse::validate(&src_info.clone()->set_is_resizable(false),
@@ -95,7 +99,11 @@ TEST_SUITE(F16)
FIXTURE_DATA_TEST_CASE(RunSmall,
NEReverseFixture<half>,
framework::DatasetMode::PRECOMMIT,
- combine(run_small_dataset, framework::dataset::make("DataType", DataType::F16)))
+ combine(
+ run_small_dataset,
+ make("DataType", DataType::F16),
+ make("use_negative_axis", { true, false }),
+ make("use_inverted_axis", { true, false })))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -104,7 +112,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
FIXTURE_DATA_TEST_CASE(RunLarge,
NEReverseFixture<half>,
framework::DatasetMode::NIGHTLY,
- combine(run_large_dataset, framework::dataset::make("DataType", DataType::F16)))
+ combine(
+ run_large_dataset,
+ make("DataType", DataType::F16),
+ make("use_negative_axis", { true, false }),
+ make("use_inverted_axis", { true, false })))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -116,7 +128,11 @@ TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall,
NEReverseFixture<float>,
framework::DatasetMode::PRECOMMIT,
- combine(run_small_dataset, framework::dataset::make("DataType", DataType::F32)))
+ combine(
+ run_small_dataset,
+ make("DataType", DataType::F32),
+ make("use_negative_axis", { true, false }),
+ make("use_inverted_axis", { true, false })))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -125,7 +141,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
FIXTURE_DATA_TEST_CASE(RunLarge,
NEReverseFixture<float>,
framework::DatasetMode::NIGHTLY,
- combine(run_large_dataset, framework::dataset::make("DataType", DataType::F32)))
+ combine(
+ run_large_dataset,
+ make("DataType", DataType::F32),
+ make("use_negative_axis", { true, false }),
+ make("use_inverted_axis", { true, false })))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -138,7 +158,11 @@ TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall,
NEReverseFixture<uint8_t>,
framework::DatasetMode::PRECOMMIT,
- combine(run_small_dataset, framework::dataset::make("DataType", DataType::QASYMM8)))
+ combine(
+ run_small_dataset,
+ make("DataType", DataType::QASYMM8),
+ make("use_negative_axis", { true, false }),
+ make("use_inverted_axis", { true, false })))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -147,7 +171,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
FIXTURE_DATA_TEST_CASE(RunLarge,
NEReverseFixture<uint8_t>,
framework::DatasetMode::NIGHTLY,
- combine(run_large_dataset, framework::dataset::make("DataType", DataType::QASYMM8)))
+ combine(
+ run_large_dataset,
+ make("DataType", DataType::QASYMM8),
+ make("use_negative_axis", { true, false }),
+ make("use_inverted_axis", { true, false })))
{
// Validate output
validate(Accessor(_target), _reference);
diff --git a/tests/validation/fixtures/ReverseFixture.h b/tests/validation/fixtures/ReverseFixture.h
index 509fd93abf..8ff8cf9421 100644
--- a/tests/validation/fixtures/ReverseFixture.h
+++ b/tests/validation/fixtures/ReverseFixture.h
@@ -21,12 +21,15 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_REVERSE_FIXTURE
-#define ARM_COMPUTE_TEST_REVERSE_FIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_REVERSEFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_REVERSEFIXTURE_H
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
+#ifdef ARM_COMPUTE_OPENCL_ENABLED
+#include "arm_compute/runtime/CL/functions/CLReverse.h"
+#endif // ARM_COMPUTE_OPENCL_ENABLED
#include "arm_compute/runtime/Tensor.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
@@ -41,14 +44,40 @@ namespace test
{
namespace validation
{
+namespace
+{
+template <typename ReverseFunction, typename TensorType>
+#ifdef ARM_COMPUTE_OPENCL_ENABLED
+std::enable_if_t < !std::is_same<ReverseFunction, CLReverse>::value, void >
+#else // ARM_COMPUTE_OPENCL_ENABLED
+void
+#endif // ARM_COMPUTE_OPENCL_ENABLED
+configureReverse(ReverseFunction &func, TensorType &src, TensorType &axis, TensorType &dst, bool use_inverted_axis)
+{
+ func.configure(&src, &dst, &axis, use_inverted_axis);
+}
+
+#ifdef ARM_COMPUTE_OPENCL_ENABLED
+template <typename ReverseFunction, typename TensorType>
+std::enable_if_t<std::is_same<ReverseFunction, CLReverse>::value, void>
+configureReverse(ReverseFunction &func, TensorType &src, TensorType &axis, TensorType &dst, bool use_inverted_axis)
+{
+ ARM_COMPUTE_UNUSED(use_inverted_axis);
+ func.configure(&src, &dst, &axis);
+}
+
+#endif // ARM_COMPUTE_OPENCL_ENABLED
+} //namespace
+
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class ReverseValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape shape, TensorShape axis_shape, DataType data_type)
+ void setup(TensorShape shape, TensorShape axis_shape, DataType data_type, bool use_negative_axis = false, bool use_inverted_axis = false)
{
- _target = compute_target(shape, axis_shape, data_type);
- _reference = compute_reference(shape, axis_shape, data_type);
+ _num_dims = shape.num_dimensions();
+ _target = compute_target(shape, axis_shape, data_type, use_negative_axis, use_inverted_axis);
+ _reference = compute_reference(shape, axis_shape, data_type, use_negative_axis, use_inverted_axis);
}
protected:
@@ -57,16 +86,25 @@ protected:
{
library->fill_tensor_uniform(tensor, 0);
}
- std::vector<int> generate_random_axis()
+ std::vector<int> generate_random_axis(bool use_negative = false)
{
- std::vector<int> axis_v = { 0, 1, 2, 3 };
- std::mt19937 g(0);
+ std::vector<int> axis_v;
+ if(use_negative)
+ {
+ axis_v = { -1, -2, -3, -4 };
+ }
+ else
+ {
+ axis_v = { 0, 1, 2, 3 };
+ }
+ axis_v = std::vector<int>(axis_v.begin(), axis_v.begin() + _num_dims);
+ std::mt19937 g(library->seed());
std::shuffle(axis_v.begin(), axis_v.end(), g);
return axis_v;
}
- TensorType compute_target(const TensorShape &shape, const TensorShape &axis_shape, DataType data_type)
+ TensorType compute_target(const TensorShape &shape, const TensorShape &axis_shape, DataType data_type, bool use_negative_axis, bool use_inverted_axis = false)
{
// Create tensors
TensorType src = create_tensor<TensorType>(shape, data_type, 1);
@@ -75,7 +113,8 @@ protected:
// Create and configure function
FunctionType reverse_func;
- reverse_func.configure(&src, &dst, &axis);
+
+ configureReverse(reverse_func, src, axis, dst, use_inverted_axis);
ARM_COMPUTE_ASSERT(src.info()->is_resizable());
ARM_COMPUTE_ASSERT(axis.info()->is_resizable());
@@ -94,7 +133,7 @@ protected:
fill(AccessorType(src));
{
auto axis_data = AccessorType(axis);
- auto axis_v = generate_random_axis();
+ auto axis_v = generate_random_axis(use_negative_axis);
std::copy(axis_v.begin(), axis_v.begin() + axis_shape.x(), static_cast<int32_t *>(axis_data.data()));
}
@@ -104,24 +143,25 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, const TensorShape &axis_shape, DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, const TensorShape &axis_shape, DataType data_type, bool use_negative_axis, bool use_inverted_axis = false)
{
// Create reference
- SimpleTensor<T> src{ shape, data_type };
- SimpleTensor<uint32_t> axis{ axis_shape, DataType::U32 };
+ SimpleTensor<T> src{ shape, data_type };
+ SimpleTensor<int32_t> axis{ axis_shape, DataType::S32 };
// Fill reference
fill(src);
- auto axis_v = generate_random_axis();
+ auto axis_v = generate_random_axis(use_negative_axis);
std::copy(axis_v.begin(), axis_v.begin() + axis_shape.x(), axis.data());
- return reference::reverse<T>(src, axis);
+ return reference::reverse<T>(src, axis, use_inverted_axis);
}
TensorType _target{};
SimpleTensor<T> _reference{};
+ unsigned int _num_dims{};
};
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_REVERSE_FIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_REVERSEFIXTURE_H
diff --git a/tests/validation/reference/DFT.cpp b/tests/validation/reference/DFT.cpp
index fd126c7d73..2b03c270ac 100644
--- a/tests/validation/reference/DFT.cpp
+++ b/tests/validation/reference/DFT.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2020, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -400,10 +400,10 @@ SimpleTensor<T> conv2d_dft(const SimpleTensor<T> &src, const SimpleTensor<T> &w,
auto padded_src = pad_layer(src, padding_in);
// Flip weights
- std::vector<uint32_t> axis_v = { 0, 1 };
- SimpleTensor<uint32_t> axis{ TensorShape(2U), DataType::U32 };
+ std::vector<uint32_t> axis_v = { 0, 1 };
+ SimpleTensor<int32_t> axis{ TensorShape(2U), DataType::S32 };
std::copy(axis_v.begin(), axis_v.begin() + axis.shape().x(), axis.data());
- auto flipped_w = reverse(w, axis);
+ auto flipped_w = reverse(w, axis, /* use_inverted_axis */ false);
// Pad weights to have the same size as input
const PaddingList paddings_w = { { 0, src.shape()[0] - 1 }, { 0, src.shape()[1] - 1 } };
diff --git a/tests/validation/reference/Reverse.cpp b/tests/validation/reference/Reverse.cpp
index c6c4614278..5fd15b5bfc 100644
--- a/tests/validation/reference/Reverse.cpp
+++ b/tests/validation/reference/Reverse.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2020, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -35,8 +35,9 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> reverse(const SimpleTensor<T> &src, const SimpleTensor<uint32_t> &axis)
+SimpleTensor<T> reverse(const SimpleTensor<T> &src, const SimpleTensor<int32_t> &axis, bool use_inverted_axis)
{
+ ARM_COMPUTE_ERROR_ON(src.shape().num_dimensions() > 4);
ARM_COMPUTE_ERROR_ON(axis.shape().num_dimensions() > 1);
ARM_COMPUTE_ERROR_ON(axis.shape().x() > 4);
@@ -48,10 +49,32 @@ SimpleTensor<T> reverse(const SimpleTensor<T> &src, const SimpleTensor<uint32_t>
const unsigned int depth = src.shape()[2];
const unsigned int batches = src.shape()[3];
+ const int rank = src.shape().num_dimensions();
+
std::array<bool, 4> to_reverse = { { false, false, false, false } };
for(int i = 0; i < axis.num_elements(); ++i)
{
- to_reverse[axis[i]] = true;
+ int axis_i = axis[i];
+
+ // The values of axis tensor must be between [-rank, rank-1].
+ if((axis_i < -rank) || (axis_i >= rank))
+ {
+ ARM_COMPUTE_ERROR("the valuses of the axis tensor must be within [-rank, rank-1].");
+ }
+
+ // In case of negative axis value i.e targeted axis(i) = rank + axis(i)
+ if(axis_i < 0)
+ {
+ axis_i = rank + axis_i;
+ }
+
+ // Reverse ACL axis indices convention i.e. (inverted)axis = (tensor_rank - 1) - axis
+ if(use_inverted_axis)
+ {
+ axis_i = (rank - 1) - axis_i;
+ }
+
+ to_reverse[axis_i] = true;
}
const uint32_t num_elements = src.num_elements();
@@ -73,9 +96,9 @@ SimpleTensor<T> reverse(const SimpleTensor<T> &src, const SimpleTensor<uint32_t>
return dst;
}
-template SimpleTensor<uint8_t> reverse(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint32_t> &axis);
-template SimpleTensor<half> reverse(const SimpleTensor<half> &src, const SimpleTensor<uint32_t> &axis);
-template SimpleTensor<float> reverse(const SimpleTensor<float> &src, const SimpleTensor<uint32_t> &axis);
+template SimpleTensor<uint8_t> reverse(const SimpleTensor<uint8_t> &src, const SimpleTensor<int32_t> &axis, bool use_inverted_axis);
+template SimpleTensor<half> reverse(const SimpleTensor<half> &src, const SimpleTensor<int32_t> &axis, bool use_inverted_axis);
+template SimpleTensor<float> reverse(const SimpleTensor<float> &src, const SimpleTensor<int32_t> &axis, bool use_inverted_axis);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/Reverse.h b/tests/validation/reference/Reverse.h
index 4a28da7270..30926b05a5 100644
--- a/tests/validation/reference/Reverse.h
+++ b/tests/validation/reference/Reverse.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2019, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_REVERSE_H
-#define ARM_COMPUTE_TEST_REVERSE_H
+#ifndef ACL_TESTS_VALIDATION_REFERENCE_REVERSE_H
+#define ACL_TESTS_VALIDATION_REFERENCE_REVERSE_H
#include "tests/SimpleTensor.h"
@@ -35,9 +35,9 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> reverse(const SimpleTensor<T> &src, const SimpleTensor<uint32_t> &axis);
+SimpleTensor<T> reverse(const SimpleTensor<T> &src, const SimpleTensor<int32_t> &axis, bool use_inverted_axis = false);
} // namespace reference
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_REVERSE_H */
+#endif // ACL_TESTS_VALIDATION_REFERENCE_REVERSE_H