aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/NEON/WidthConcatenateLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/NEON/WidthConcatenateLayer.cpp')
-rw-r--r--tests/validation/NEON/WidthConcatenateLayer.cpp35
1 files changed, 21 insertions, 14 deletions
diff --git a/tests/validation/NEON/WidthConcatenateLayer.cpp b/tests/validation/NEON/WidthConcatenateLayer.cpp
index 6e94e92d05..dba14ebb35 100644
--- a/tests/validation/NEON/WidthConcatenateLayer.cpp
+++ b/tests/validation/NEON/WidthConcatenateLayer.cpp
@@ -22,7 +22,7 @@
* SOFTWARE.
*/
#include "arm_compute/core/Types.h"
-#include "arm_compute/runtime/NEON/functions/NEWidthConcatenateLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h"
#include "arm_compute/runtime/Tensor.h"
#include "arm_compute/runtime/TensorAllocator.h"
#include "tests/NEON/Accessor.h"
@@ -31,7 +31,7 @@
#include "tests/framework/Macros.h"
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"
-#include "tests/validation/fixtures/WidthConcatenateLayerFixture.h"
+#include "tests/validation/fixtures/ConcatenateLayerFixture.h"
namespace arm_compute
{
@@ -72,27 +72,30 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
inputs_vector_info_raw.emplace_back(&input);
}
- bool is_valid = bool(NEWidthConcatenateLayer::validate(inputs_vector_info_raw,
- &output_info.clone()->set_is_resizable(false)));
+ bool is_valid = bool(NEConcatenateLayer::validate(inputs_vector_info_raw,
+ &output_info.clone()->set_is_resizable(false),DataLayoutDimension::WIDTH));
ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
}
// clang-format on
// *INDENT-ON*
-
template <typename T>
-using NEWidthConcatenateLayerFixture = WidthConcatenateLayerValidationFixture<Tensor, ITensor, Accessor, NEWidthConcatenateLayer, T>;
+using NEWidthConcatenateLayerFixture = ConcatenateLayerValidationFixture<Tensor, ITensor, Accessor, NEConcatenateLayer, T>;
TEST_SUITE(Float)
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEWidthConcatenateLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(concat(datasets::Small2DShapes(), datasets::Tiny4DShapes()),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEWidthConcatenateLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(concat(datasets::Small2DShapes(), datasets::Tiny4DShapes()),
framework::dataset::make("DataType",
- DataType::F32)))
+ DataType::F32)),
+ framework::dataset::make("Axis", 0)))
+
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEWidthConcatenateLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::WidthConcatenateLayerShapes(), framework::dataset::make("DataType",
- DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEWidthConcatenateLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::ConcatenateLayerShapes(), framework::dataset::make("DataType",
+ DataType::F32)),
+ framework::dataset::make("Axis", 0)))
+
{
// Validate output
validate(Accessor(_target), _reference);
@@ -102,15 +105,19 @@ TEST_SUITE_END()
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEWidthConcatenateLayerFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(concat(datasets::Small2DShapes(), datasets::Tiny4DShapes()),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEWidthConcatenateLayerFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(concat(datasets::Small2DShapes(), datasets::Tiny4DShapes()),
framework::dataset::make("DataType",
- DataType::QASYMM8)))
+ DataType::QASYMM8)),
+ framework::dataset::make("Axis", 0)))
+
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEWidthConcatenateLayerFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(datasets::WidthConcatenateLayerShapes(), framework::dataset::make("DataType",
- DataType::QASYMM8)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEWidthConcatenateLayerFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::ConcatenateLayerShapes(), framework::dataset::make("DataType",
+ DataType::QASYMM8)),
+ framework::dataset::make("Axis", 0)))
+
{
// Validate output
validate(Accessor(_target), _reference);