From e383c35c336002ce15945ed48facd7d4ba715aa8 Mon Sep 17 00:00:00 2001 From: morgolock Date: Fri, 3 Apr 2020 16:57:46 +0100 Subject: MLCE-166: Add support for extracting indices in NEPoolingLayer 2x2 NHWC * Added support for pooling indices in NHWC Poolsize 2x2 Change-Id: Ib2a3468e794f58bbf2c03aba9f6b184b9d76b183 Signed-off-by: morgolock Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2997 Tested-by: Arm Jenkins Reviewed-by: Manuel Bottini Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- tests/validation/fixtures/PoolingLayerFixture.h | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) (limited to 'tests/validation/fixtures') diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h index 7f2d7ac225..eb40cea0c2 100644 --- a/tests/validation/fixtures/PoolingLayerFixture.h +++ b/tests/validation/fixtures/PoolingLayerFixture.h @@ -35,7 +35,6 @@ #include "tests/framework/Fixture.h" #include "tests/validation/reference/PoolingLayer.h" #include - namespace arm_compute { namespace test @@ -59,7 +58,7 @@ public: _pool_info = pool_info; _target = compute_target(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices); - _reference = compute_reference(shape, pool_info, data_type, input_qinfo, output_qinfo, indices); + _reference = compute_reference(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices); } protected: @@ -92,7 +91,7 @@ protected: TensorType src = create_tensor(shape, data_type, 1, input_qinfo, data_layout); const TensorShape dst_shape = misc::shape_calculator::compute_pool_shape(*(src.info()), info); TensorType dst = create_tensor(dst_shape, data_type, 1, output_qinfo, data_layout); - _target_indices = create_tensor(dst_shape, DataType::U32, 1); + _target_indices = create_tensor(dst_shape, DataType::U32, 1, output_qinfo, data_layout); // Create and configure function FunctionType pool_layer; @@ -120,15 +119,14 @@ protected: return dst; } - SimpleTensor compute_reference(const TensorShape &shape, PoolingLayerInfo info, DataType data_type, + SimpleTensor compute_reference(TensorShape shape, PoolingLayerInfo info, DataType data_type, DataLayout data_layout, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo, bool indices) { // Create reference - SimpleTensor src{ shape, data_type, 1, input_qinfo }; + SimpleTensor src(shape, data_type, 1, input_qinfo); // Fill reference fill(src); - - return reference::pooling_layer(src, info, output_qinfo, indices ? &_ref_indices : nullptr); + return reference::pooling_layer(src, info, output_qinfo, indices ? &_ref_indices : nullptr, data_layout); } TensorType _target{}; -- cgit v1.2.1