aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorAdnan AlSinan <adnan.alsinan@arm.com>2023-10-24 11:05:56 +0100
committerAdnan AlSinan <adnan.alsinan@arm.com>2023-10-31 14:23:31 +0000
commit704c22f1373e1276acb43c71e7e17048271bbc03 (patch)
tree7cf8b5d4730c6482229a228215dd80b794088735 /tests
parent8f4b3df4c59c7b1c3fbea5b559862fcefeba14bf (diff)
downloadComputeLibrary-704c22f1373e1276acb43c71e7e17048271bbc03.tar.gz
[GPU] Update Reverse layer to allow negative axis and reversed axis order
- Adds option to use negative axis and inverted axis. - Adds validation tests for the above. Resolves COMPMID-6459 Change-Id: I88afd845d078f92c82ec8529ce7241fccd4c417e Signed-off-by: Adnan AlSinan <adnan.alsinan@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10523 Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/validation/CL/Reverse.cpp29
-rw-r--r--tests/validation/fixtures/ReverseFixture.h31
2 files changed, 16 insertions, 44 deletions
diff --git a/tests/validation/CL/Reverse.cpp b/tests/validation/CL/Reverse.cpp
index ff46ba64ad..82effc2136 100644
--- a/tests/validation/CL/Reverse.cpp
+++ b/tests/validation/CL/Reverse.cpp
@@ -44,7 +44,7 @@ namespace validation
using framework::dataset::make;
namespace
{
-auto run_small_dataset = combine(datasets::SmallShapes(), datasets::Tiny1DShapes());
+auto run_small_dataset = combine(datasets::Small3DShapes(), datasets::Tiny1DShapes());
auto run_large_dataset = combine(datasets::LargeShapes(), datasets::Tiny1DShapes());
} // namespace
@@ -80,7 +80,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
{
Status s = CLReverse::validate(&src_info.clone()->set_is_resizable(false),
&dst_info.clone()->set_is_resizable(false),
- &axis_info.clone()->set_is_resizable(false));
+ &axis_info.clone()->set_is_resizable(false),
+ false);
ARM_COMPUTE_EXPECT(bool(s) == expected, framework::LogLevel::ERRORS);
}
// clang-format on
@@ -97,8 +98,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
combine(
run_small_dataset,
make("DataType", DataType::F16),
- make("use_negative_axis", { false }),
- make("use_inverted_axis", { false })))
+ make("use_negative_axis", { true, false }),
+ make("use_inverted_axis", { true, false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -110,8 +111,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge,
combine(
run_large_dataset,
make("DataType", DataType::F16),
- make("use_negative_axis", { false }),
- make("use_inverted_axis", { false })))
+ make("use_negative_axis", { true, false }),
+ make("use_inverted_axis", { true, false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -125,8 +126,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
combine(
run_small_dataset,
make("DataType", DataType::F32),
- make("use_negative_axis", { false }),
- make("use_inverted_axis", { false })))
+ make("use_negative_axis", { true, false }),
+ make("use_inverted_axis", { true, false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -138,8 +139,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge,
combine(
run_large_dataset,
make("DataType", DataType::F32),
- make("use_negative_axis", { false }),
- make("use_inverted_axis", { false })))
+ make("use_negative_axis", { true, false }),
+ make("use_inverted_axis", { true, false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -155,8 +156,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
combine(
run_small_dataset,
make("DataType", DataType::QASYMM8),
- make("use_negative_axis", { false }),
- make("use_inverted_axis", { false })))
+ make("use_negative_axis", { true, false }),
+ make("use_inverted_axis", { true, false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -168,8 +169,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge,
combine(
run_large_dataset,
make("DataType", DataType::QASYMM8),
- make("use_negative_axis", { false }),
- make("use_inverted_axis", { false })))
+ make("use_negative_axis", { true, false }),
+ make("use_inverted_axis", { true, false })))
{
// Validate output
validate(CLAccessor(_target), _reference);
diff --git a/tests/validation/fixtures/ReverseFixture.h b/tests/validation/fixtures/ReverseFixture.h
index d53945191e..856bff7b12 100644
--- a/tests/validation/fixtures/ReverseFixture.h
+++ b/tests/validation/fixtures/ReverseFixture.h
@@ -27,9 +27,6 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
-#ifdef ARM_COMPUTE_OPENCL_ENABLED
-#include "arm_compute/runtime/CL/functions/CLReverse.h"
-#endif // ARM_COMPUTE_OPENCL_ENABLED
#include "arm_compute/runtime/Tensor.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
@@ -44,31 +41,6 @@ namespace test
{
namespace validation
{
-namespace
-{
-template <typename ReverseFunction, typename TensorType>
-#ifdef ARM_COMPUTE_OPENCL_ENABLED
-std::enable_if_t < !std::is_same<ReverseFunction, CLReverse>::value, void >
-#else // ARM_COMPUTE_OPENCL_ENABLED
-void
-#endif // ARM_COMPUTE_OPENCL_ENABLED
-configureReverse(ReverseFunction &func, TensorType &src, TensorType &axis, TensorType &dst, bool use_inverted_axis)
-{
- func.configure(&src, &dst, &axis, use_inverted_axis);
-}
-
-#ifdef ARM_COMPUTE_OPENCL_ENABLED
-template <typename ReverseFunction, typename TensorType>
-std::enable_if_t<std::is_same<ReverseFunction, CLReverse>::value, void>
-configureReverse(ReverseFunction &func, TensorType &src, TensorType &axis, TensorType &dst, bool use_inverted_axis)
-{
- ARM_COMPUTE_UNUSED(use_inverted_axis);
- func.configure(&src, &dst, &axis);
-}
-
-#endif // ARM_COMPUTE_OPENCL_ENABLED
-} //namespace
-
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class ReverseValidationFixture : public framework::Fixture
{
@@ -113,8 +85,7 @@ protected:
// Create and configure function
FunctionType reverse_func;
-
- configureReverse(reverse_func, src, axis, dst, use_inverted_axis);
+ reverse_func.configure(&src, &dst, &axis, use_inverted_axis);
ARM_COMPUTE_ASSERT(src.info()->is_resizable());
ARM_COMPUTE_ASSERT(axis.info()->is_resizable());