From cc1f6c94f1fc3b5d5ccbd5aa43e2a08487664f50 Mon Sep 17 00:00:00 2001 From: morgolock Date: Tue, 24 Mar 2020 09:26:48 +0000 Subject: MLCE-166: Add support for extracting indices in NEPoolingLayer 2x2 NCHW * Added initial support for pooling indices * Only supported for NCHW Poolsize 2 Change-Id: I92ce767e64fcc01aae89411064b4cb2be272a1e9 Signed-off-by: morgolock Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2927 Comments-Addressed: Arm Jenkins Reviewed-by: Georgios Pinitas Reviewed-by: Sang-Hoon Park Tested-by: Arm Jenkins --- tests/validation/NEON/PoolingLayer.cpp | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) (limited to 'tests/validation/NEON/PoolingLayer.cpp') diff --git a/tests/validation/NEON/PoolingLayer.cpp b/tests/validation/NEON/PoolingLayer.cpp index 1012320b0d..a5876dcd0a 100644 --- a/tests/validation/NEON/PoolingLayer.cpp +++ b/tests/validation/NEON/PoolingLayer.cpp @@ -111,14 +111,33 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( // clang-format on // *INDENT-ON* +template +using NEPoolingLayerIndicesFixture = PoolingLayerIndicesValidationFixture; + template using NEPoolingLayerFixture = PoolingLayerValidationFixture; template using NESpecialPoolingLayerFixture = SpecialPoolingLayerValidationFixture; +const auto PoolingLayerIndicesDatasetFPSmall = combine(combine(combine(framework::dataset::make("PoolType", { PoolingType::MAX }), framework::dataset::make("PoolingSize", { Size2D(2, 2) })), + framework::dataset::make("PadStride", { PadStrideInfo(1, 1, 0, 0), PadStrideInfo(2, 1, 0, 0) })), + framework::dataset::make("ExcludePadding", { true, false })); + TEST_SUITE(Float) TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunIndices, NEPoolingLayerIndicesFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerIndicesDatasetFPSmall, + framework::dataset::make("DataType", + DataType::F32))), + framework::dataset::make("DataLayout", { DataLayout::NCHW }) + + )) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_f32); + validate(Accessor(_target_indices), _ref_indices); +} + FIXTURE_DATA_TEST_CASE(RunSpecial, NESpecialPoolingLayerFixture, framework::DatasetMode::ALL, datasets::PoolingLayerDatasetSpecial() * framework::dataset::make("DataType", DataType::F32)) { // Validate output -- cgit v1.2.1