aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/Reverse.cpp
diff options
context:
space:
mode:
authorAdnan AlSinan <adnan.alsinan@arm.com>2023-09-18 14:49:45 +0100
committerAdnan AlSinan <adnan.alsinan@arm.com>2023-09-27 15:13:29 +0000
commitbdcb4c148ee2fdeaaddf4cf1e57bbb0de02bb894 (patch)
treeee9743ddfe42b800bbc54dc3c273c188cb779017 /tests/validation/CL/Reverse.cpp
parent729099c5d134c2c34459a2bdbd5453ad4ca68cac (diff)
downloadComputeLibrary-bdcb4c148ee2fdeaaddf4cf1e57bbb0de02bb894.tar.gz
Implement tflite compliant reverse for CPU
- Add support for negative axis values. - Add option to use opposite ACL convention for dimension addressing. - Add validation tests for the mentioned additions. Resolves COMPMID-6497 Change-Id: I9174b201c3adc070766cc6cffcbe4ec1fe5ec1c3 Signed-off-by: Adnan AlSinan <adnan.alsinan@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10335 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: SiCong Li <sicong.li@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/CL/Reverse.cpp')
-rw-r--r--tests/validation/CL/Reverse.cpp47
1 files changed, 36 insertions, 11 deletions
diff --git a/tests/validation/CL/Reverse.cpp b/tests/validation/CL/Reverse.cpp
index 11df0e7803..ff46ba64ad 100644
--- a/tests/validation/CL/Reverse.cpp
+++ b/tests/validation/CL/Reverse.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2020, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,6 +41,7 @@ namespace test
{
namespace validation
{
+using framework::dataset::make;
namespace
{
auto run_small_dataset = combine(datasets::SmallShapes(), datasets::Tiny1DShapes());
@@ -53,28 +54,28 @@ TEST_SUITE(Reverse)
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
- framework::dataset::make("InputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8), // Invalid axis datatype
+ make("InputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8), // Invalid axis datatype
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid axis shape
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid axis length (> 4)
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Mismatching shapes
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(2U), 1, DataType::U8),
}),
- framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8),
+ make("OutputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(2U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(2U), 1, DataType::U8),
})),
- framework::dataset::make("AxisInfo",{ TensorInfo(TensorShape(3U), 1, DataType::U8),
+ make("AxisInfo",{ TensorInfo(TensorShape(3U), 1, DataType::U8),
TensorInfo(TensorShape(2U, 10U), 1, DataType::U32),
TensorInfo(TensorShape(8U), 1, DataType::U32),
TensorInfo(TensorShape(2U), 1, DataType::U32),
TensorInfo(TensorShape(2U), 1, DataType::U32),
TensorInfo(TensorShape(2U), 1, DataType::U32),
})),
- framework::dataset::make("Expected", { false, false, false, false, true, true})),
+ make("Expected", { false, false, false, false, true, true})),
src_info, dst_info, axis_info, expected)
{
Status s = CLReverse::validate(&src_info.clone()->set_is_resizable(false),
@@ -93,7 +94,11 @@ TEST_SUITE(F16)
FIXTURE_DATA_TEST_CASE(RunSmall,
CLReverseFixture<half>,
framework::DatasetMode::PRECOMMIT,
- combine(run_small_dataset, framework::dataset::make("DataType", DataType::F16)))
+ combine(
+ run_small_dataset,
+ make("DataType", DataType::F16),
+ make("use_negative_axis", { false }),
+ make("use_inverted_axis", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -102,7 +107,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
FIXTURE_DATA_TEST_CASE(RunLarge,
CLReverseFixture<half>,
framework::DatasetMode::NIGHTLY,
- combine(run_large_dataset, framework::dataset::make("DataType", DataType::F16)))
+ combine(
+ run_large_dataset,
+ make("DataType", DataType::F16),
+ make("use_negative_axis", { false }),
+ make("use_inverted_axis", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -113,7 +122,11 @@ TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall,
CLReverseFixture<float>,
framework::DatasetMode::PRECOMMIT,
- combine(run_small_dataset, framework::dataset::make("DataType", DataType::F32)))
+ combine(
+ run_small_dataset,
+ make("DataType", DataType::F32),
+ make("use_negative_axis", { false }),
+ make("use_inverted_axis", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -122,7 +135,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
FIXTURE_DATA_TEST_CASE(RunLarge,
CLReverseFixture<float>,
framework::DatasetMode::NIGHTLY,
- combine(run_large_dataset, framework::dataset::make("DataType", DataType::F32)))
+ combine(
+ run_large_dataset,
+ make("DataType", DataType::F32),
+ make("use_negative_axis", { false }),
+ make("use_inverted_axis", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -135,7 +152,11 @@ TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall,
CLReverseFixture<uint8_t>,
framework::DatasetMode::PRECOMMIT,
- combine(run_small_dataset, framework::dataset::make("DataType", DataType::QASYMM8)))
+ combine(
+ run_small_dataset,
+ make("DataType", DataType::QASYMM8),
+ make("use_negative_axis", { false }),
+ make("use_inverted_axis", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -144,7 +165,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
FIXTURE_DATA_TEST_CASE(RunLarge,
CLReverseFixture<uint8_t>,
framework::DatasetMode::NIGHTLY,
- combine(run_large_dataset, framework::dataset::make("DataType", DataType::QASYMM8)))
+ combine(
+ run_large_dataset,
+ make("DataType", DataType::QASYMM8),
+ make("use_negative_axis", { false }),
+ make("use_inverted_axis", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference);