aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/NEON/PoolingLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/NEON/PoolingLayer.cpp')
-rw-r--r--tests/validation/NEON/PoolingLayer.cpp19
1 files changed, 19 insertions, 0 deletions
diff --git a/tests/validation/NEON/PoolingLayer.cpp b/tests/validation/NEON/PoolingLayer.cpp
index 1012320b0d..a5876dcd0a 100644
--- a/tests/validation/NEON/PoolingLayer.cpp
+++ b/tests/validation/NEON/PoolingLayer.cpp
@@ -112,13 +112,32 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
// *INDENT-ON*
template <typename T>
+using NEPoolingLayerIndicesFixture = PoolingLayerIndicesValidationFixture<Tensor, Accessor, NEPoolingLayer, T>;
+
+template <typename T>
using NEPoolingLayerFixture = PoolingLayerValidationFixture<Tensor, Accessor, NEPoolingLayer, T>;
template <typename T>
using NESpecialPoolingLayerFixture = SpecialPoolingLayerValidationFixture<Tensor, Accessor, NEPoolingLayer, T>;
+const auto PoolingLayerIndicesDatasetFPSmall = combine(combine(combine(framework::dataset::make("PoolType", { PoolingType::MAX }), framework::dataset::make("PoolingSize", { Size2D(2, 2) })),
+ framework::dataset::make("PadStride", { PadStrideInfo(1, 1, 0, 0), PadStrideInfo(2, 1, 0, 0) })),
+ framework::dataset::make("ExcludePadding", { true, false }));
+
TEST_SUITE(Float)
TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunIndices, NEPoolingLayerIndicesFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerIndicesDatasetFPSmall,
+ framework::dataset::make("DataType",
+ DataType::F32))),
+ framework::dataset::make("DataLayout", { DataLayout::NCHW })
+
+ ))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_f32);
+ validate(Accessor(_target_indices), _ref_indices);
+}
+
FIXTURE_DATA_TEST_CASE(RunSpecial, NESpecialPoolingLayerFixture<float>, framework::DatasetMode::ALL, datasets::PoolingLayerDatasetSpecial() * framework::dataset::make("DataType", DataType::F32))
{
// Validate output