From cc1f6c94f1fc3b5d5ccbd5aa43e2a08487664f50 Mon Sep 17 00:00:00 2001 From: morgolock Date: Tue, 24 Mar 2020 09:26:48 +0000 Subject: MLCE-166: Add support for extracting indices in NEPoolingLayer 2x2 NCHW * Added initial support for pooling indices * Only supported for NCHW Poolsize 2 Change-Id: I92ce767e64fcc01aae89411064b4cb2be272a1e9 Signed-off-by: morgolock Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2927 Comments-Addressed: Arm Jenkins Reviewed-by: Georgios Pinitas Reviewed-by: Sang-Hoon Park Tested-by: Arm Jenkins --- src/runtime/CL/functions/CLPoolingLayer.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'src/runtime/CL/functions') diff --git a/src/runtime/CL/functions/CLPoolingLayer.cpp b/src/runtime/CL/functions/CLPoolingLayer.cpp index ebdae0b8ad..9c4fa4a2ba 100644 --- a/src/runtime/CL/functions/CLPoolingLayer.cpp +++ b/src/runtime/CL/functions/CLPoolingLayer.cpp @@ -30,14 +30,13 @@ namespace arm_compute { -void CLPoolingLayer::configure(ICLTensor *input, ICLTensor *output, const PoolingLayerInfo &pool_info) +void CLPoolingLayer::configure(ICLTensor *input, ICLTensor *output, const PoolingLayerInfo &pool_info, ICLTensor *indices) { ARM_COMPUTE_ERROR_ON_NULLPTR(input); - // Configure pooling kernel auto k = arm_compute::support::cpp14::make_unique(); k->set_target(CLScheduler::get().target()); - k->configure(input, output, pool_info); + k->configure(input, output, pool_info, indices); _kernel = std::move(k); const DataType data_type = input->info()->data_type(); @@ -81,8 +80,8 @@ void CLPoolingLayer::configure(ICLTensor *input, ICLTensor *output, const Poolin CLScheduler::get().tune_kernel_static(*_kernel); } -Status CLPoolingLayer::validate(const ITensorInfo *input, const ITensorInfo *output, const PoolingLayerInfo &pool_info) +Status CLPoolingLayer::validate(const ITensorInfo *input, const ITensorInfo *output, const PoolingLayerInfo &pool_info, const ITensorInfo *indices) { - return CLPoolingLayerKernel::validate(input, output, pool_info); + return CLPoolingLayerKernel::validate(input, output, pool_info, indices); } } // namespace arm_compute -- cgit v1.2.1