aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/WidthConcatenateLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/WidthConcatenateLayer.cpp')
-rw-r--r--tests/validation/CL/WidthConcatenateLayer.cpp48
1 files changed, 28 insertions, 20 deletions
diff --git a/tests/validation/CL/WidthConcatenateLayer.cpp b/tests/validation/CL/WidthConcatenateLayer.cpp
index 647e0413a1..493320b9ad 100644
--- a/tests/validation/CL/WidthConcatenateLayer.cpp
+++ b/tests/validation/CL/WidthConcatenateLayer.cpp
@@ -24,14 +24,14 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/CLTensorAllocator.h"
-#include "arm_compute/runtime/CL/functions/CLWidthConcatenateLayer.h"
+#include "arm_compute/runtime/CL/functions/CLConcatenateLayer.h"
#include "tests/CL/CLAccessor.h"
#include "tests/datasets/ShapeDatasets.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Macros.h"
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"
-#include "tests/validation/fixtures/WidthConcatenateLayerFixture.h"
+#include "tests/validation/fixtures/ConcatenateLayerFixture.h"
namespace arm_compute
{
@@ -72,8 +72,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
inputs_vector_info_raw.emplace_back(&input);
}
- bool is_valid = bool(CLWidthConcatenateLayer::validate(inputs_vector_info_raw,
- &output_info.clone()->set_is_resizable(false)));
+ bool is_valid = bool(CLConcatenateLayer::validate(inputs_vector_info_raw,&output_info.clone()->set_is_resizable(false),DataLayoutDimension::WIDTH ));
ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
}
// clang-format on
@@ -93,26 +92,30 @@ TEST_CASE(Configuration, framework::DatasetMode::ALL)
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
// Create and configure function
- CLWidthConcatenateLayer concat_layer;
+ CLConcatenateLayer concat_layer;
- concat_layer.configure({ &src1, &src2, &src3 }, &dst);
+ concat_layer.configure({ &src1, &src2, &src3 }, &dst, DataLayoutDimension::WIDTH);
}
template <typename T>
-using CLWidthConcatenateLayerFixture = WidthConcatenateLayerValidationFixture<CLTensor, ICLTensor, CLAccessor, CLWidthConcatenateLayer, T>;
+using CLWidthConcatenateLayerFixture = ConcatenateLayerValidationFixture<CLTensor, ICLTensor, CLAccessor, CLConcatenateLayer, T>;
TEST_SUITE(Float)
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLWidthConcatenateLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(concat(datasets::Small2DShapes(), datasets::Tiny4DShapes()),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLWidthConcatenateLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(concat(datasets::Small2DShapes(), datasets::Tiny4DShapes()),
framework::dataset::make("DataType",
- DataType::F16)))
+ DataType::F16)),
+ framework::dataset::make("Axis", 0)))
+
{
// Validate output
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLWidthConcatenateLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(concat(datasets::Large2DShapes(), datasets::Small4DShapes()),
- framework::dataset::make("DataType",
- DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLWidthConcatenateLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(concat(datasets::Large2DShapes(), datasets::Small4DShapes()),
+ framework::dataset::make("DataType",
+ DataType::F16)),
+ framework::dataset::make("Axis", 0)))
+
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -120,15 +123,18 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLWidthConcatenateLayerFixture<half>, framework
TEST_SUITE_END()
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLWidthConcatenateLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(concat(datasets::Small2DShapes(), datasets::Tiny4DShapes()),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLWidthConcatenateLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(concat(datasets::Small2DShapes(), datasets::Tiny4DShapes()),
framework::dataset::make("DataType",
- DataType::F32)))
+ DataType::F32)),
+ framework::dataset::make("Axis", 0)))
+
{
// Validate output
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLWidthConcatenateLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::WidthConcatenateLayerShapes(), framework::dataset::make("DataType",
- DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLWidthConcatenateLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::ConcatenateLayerShapes(), framework::dataset::make("DataType",
+ DataType::F32)),
+ framework::dataset::make("Axis", 0)))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -138,15 +144,17 @@ TEST_SUITE_END()
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLWidthConcatenateLayerFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(concat(datasets::Small2DShapes(), datasets::Tiny4DShapes()),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLWidthConcatenateLayerFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(concat(datasets::Small2DShapes(), datasets::Tiny4DShapes()),
framework::dataset::make("DataType",
- DataType::QASYMM8)))
+ DataType::QASYMM8)),
+ framework::dataset::make("Axis", 0)))
{
// Validate output
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLWidthConcatenateLayerFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(datasets::WidthConcatenateLayerShapes(), framework::dataset::make("DataType",
- DataType::QASYMM8)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLWidthConcatenateLayerFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::ConcatenateLayerShapes(), framework::dataset::make("DataType",
+ DataType::QASYMM8)),
+ framework::dataset::make("Axis", 0)))
{
// Validate output
validate(CLAccessor(_target), _reference);