From bbf2e7477be984702e1a51f2a23910ee8349b867 Mon Sep 17 00:00:00 2001
From: Adnan AlSinan
Date: Wed, 22 Feb 2023 12:15:14 +0000
Subject: Add support for kernel indices in Maxpool

- Add a max pooling implementation that returns kernel indices.
- Add a parameter in pooling info object to pick kernel indices impl.
- Add validation tests.

Resolves: [ONCPUML-1187]

Signed-off-by: Adnan AlSinan
Change-Id: I485ef1604f676ee14d5f7f62d33699e49c38e4d3
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9192
Reviewed-by: Gunes Bayir
Tested-by: Arm Jenkins
Comments-Addressed: Arm Jenkins
Benchmark: Arm Jenkins
---
 arm_compute/core/Types.h                        |  21 +++-
 src/cpu/kernels/CpuPool2dKernel.cpp             |  11 +-
 src/cpu/kernels/pool2d/neon/fp32.cpp            | 143 ++++++++++++++++++++++--
 src/cpu/operators/CpuPool2d.cpp                 |   7 +-
 src/cpu/operators/CpuPool2d.h                   |   3 +-
 tests/validation/CL/PoolingLayer.cpp            |  10 +-
 tests/validation/NEON/PoolingLayer.cpp          |  36 ++++--
 tests/validation/fixtures/PoolingLayerFixture.h |   7 +-
 tests/validation/reference/PoolingLayer.cpp     |  16 ++-
 9 files changed, 203 insertions(+), 51 deletions(-)
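Note, not part of the commit: the new flag is plumbed from PoolingLayerInfo through NEPoolingLayer, so callers opt in at configure time. A minimal usage sketch, assuming the usual Tensor/TensorInfo setup — shapes and pool parameters below are illustrative, not taken from this patch:

// Sketch only; shapes/strides are made up. The PoolingLayerInfo signature
// matches the one added in this patch.
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEPoolingLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    // NHWC shapes are (C, W, H, N) in TensorShape order.
    TensorInfo src_info(TensorShape(32U, 16U, 16U, 1U), 1, DataType::F32);
    TensorInfo dst_info(TensorShape(32U, 7U, 7U, 1U), 1, DataType::F32);
    TensorInfo idx_info(TensorShape(32U, 7U, 7U, 1U), 1, DataType::U32);
    src_info.set_data_layout(DataLayout::NHWC);
    dst_info.set_data_layout(DataLayout::NHWC);
    idx_info.set_data_layout(DataLayout::NHWC);

    Tensor src, dst, indices;
    src.allocator()->init(src_info);
    dst.allocator()->init(dst_info);
    indices.allocator()->init(idx_info);

    // Last argument is the new use_kernel_indices flag: indices will hold
    // positions inside each pooling window, not source-tensor offsets.
    const PoolingLayerInfo pool_info(PoolingType::MAX, Size2D(3, 3), DataLayout::NHWC,
                                     PadStrideInfo(2, 2, 0, 0),
                                     false /* exclude_padding */, false /* fp_mixed_precision */,
                                     true /* use_inf_as_limit */, true /* use_kernel_indices */);

    NEPoolingLayer pool;
    pool.configure(&src, &dst, pool_info, &indices);

    src.allocator()->allocate();
    dst.allocator()->allocate();
    indices.allocator()->allocate();

    pool.run();
    return 0;
}

With use_kernel_indices set, each entry of the indices tensor is the row-major position of the maximum inside its own pooling window (ky * pool_width + kx) rather than an offset into the source tensor.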
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index e8eed67c58..946b8a6cb6 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1210,7 +1210,8 @@ struct PoolingLayerInfo
         exclude_padding(false),
         is_global_pooling(false),
         fp_mixed_precision(false),
-        use_inf_as_limit(true)
+        use_inf_as_limit(true),
+        use_kernel_indices(false)
     {
     }
     /** Constructor
@@ -1224,6 +1225,7 @@
      *                                Defaults to false;
      * @param[in] fp_mixed_precision  (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
      * @param[in] use_inf_as_limit    (Optional) Use inf to represent the limits of datatypes range, instead of using "lowest" property of the data type.
+     * @param[in] use_kernel_indices  (Optional) Use kernel indices instead of using source indices while computing indices tensor.
      */
     explicit PoolingLayerInfo(PoolingType   pool_type,
                               unsigned int  pool_size,
@@ -1231,7 +1233,8 @@ struct PoolingLayerInfo
                               PadStrideInfo pad_stride_info = PadStrideInfo(),
                               bool          exclude_padding = false,
                               bool          fp_mixed_precision = false,
-                              bool          use_inf_as_limit = true)
+                              bool          use_inf_as_limit = true,
+                              bool          use_kernel_indices = false)
         : pool_type(pool_type),
           pool_size(Size2D(pool_size, pool_size)),
           data_layout(data_layout),
@@ -1239,7 +1242,8 @@ struct PoolingLayerInfo
           exclude_padding(exclude_padding),
           is_global_pooling(false),
           fp_mixed_precision(fp_mixed_precision),
-          use_inf_as_limit(use_inf_as_limit)
+          use_inf_as_limit(use_inf_as_limit),
+          use_kernel_indices(use_kernel_indices)
     {
     }
 
@@ -1254,6 +1258,7 @@
      *                                Defaults to false;
      * @param[in] fp_mixed_precision  (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
      * @param[in] use_inf_as_limit    (Optional) Use inf to represent the limits of datatypes range, instead of using "lowest" property of the data type.
+     * @param[in] use_kernel_indices  (Optional) Use kernel indices instead of using source indices while computing indices tensor.
      */
     explicit PoolingLayerInfo(PoolingType   pool_type,
                               Size2D        pool_size,
@@ -1261,7 +1266,8 @@ struct PoolingLayerInfo
                               PadStrideInfo pad_stride_info = PadStrideInfo(),
                               bool          exclude_padding = false,
                               bool          fp_mixed_precision = false,
-                              bool          use_inf_as_limit = true)
+                              bool          use_inf_as_limit = true,
+                              bool          use_kernel_indices = false)
         : pool_type(pool_type),
           pool_size(pool_size),
           data_layout(data_layout),
@@ -1269,7 +1275,8 @@ struct PoolingLayerInfo
           exclude_padding(exclude_padding),
           is_global_pooling(false),
           fp_mixed_precision(fp_mixed_precision),
-          use_inf_as_limit(use_inf_as_limit)
+          use_inf_as_limit(use_inf_as_limit),
+          use_kernel_indices(use_kernel_indices)
     {
     }
 
@@ -1288,7 +1295,8 @@
           exclude_padding(false),
           is_global_pooling(true),
           fp_mixed_precision(false),
-          use_inf_as_limit(true)
+          use_inf_as_limit(true),
+          use_kernel_indices(false)
     {
     }
 
@@ -1300,6 +1308,7 @@
     bool          is_global_pooling;
     bool          fp_mixed_precision;
     bool          use_inf_as_limit;
+    bool          use_kernel_indices;
 };
 
 /** Pooling Layer Information struct*/
diff --git a/src/cpu/kernels/CpuPool2dKernel.cpp b/src/cpu/kernels/CpuPool2dKernel.cpp
index 8f04812b0c..d72a41cbbe 100644
--- a/src/cpu/kernels/CpuPool2dKernel.cpp
+++ b/src/cpu/kernels/CpuPool2dKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,17 +28,11 @@
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/Window.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "src/core/AccessWindowStatic.h"
 #include "src/core/CPP/Validate.h"
-#include "src/core/NEON/NEAsymm.h"
-#include "src/core/NEON/NEFixedPoint.h"
-#include "src/core/NEON/NEMath.h"
 #include "src/core/common/Registrars.h"
 #include "src/core/helpers/AutoConfiguration.h"
 #include "src/core/helpers/WindowHelpers.h"
 #include "src/cpu/kernels/pool2d/neon/list.h"
-#include "support/ToolchainSupport.h"
-
 #include "src/core/NEON/wrapper/wrapper.h"
 
 #include <arm_neon.h>
@@ -191,7 +185,8 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, &out_info);
         if(indices)
         {
-            ARM_COMPUTE_RETURN_ERROR_ON_MSG((pool_size != Size2D(2, 2)), "Pooling indices only supported for pool size 2x2");
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(((pool_size != Size2D(2, 2)) && !pool_info.use_kernel_indices), "Pooling indices returning source tensor coordinates is only supported for pool size 2x2");
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(pool_info.use_kernel_indices && (src->data_layout() != DataLayout::NHWC), "Pooling kernel indices only supported for NHWC");
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(indices, &out_info);
         }
     }
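The validation hunk above encodes the two support rules: source-coordinate indices still require a 2x2 pool, and kernel indices are NHWC-only. A hedged sketch of probing them through the public validate() entry point before configuring — the tensor metadata is the caller's, and the helper name is ours:

// Sketch only: assumes the NEPoolingLayer::validate() overload that takes an
// indices ITensorInfo, which is the path this patch extends.
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/NEON/functions/NEPoolingLayer.h"

using namespace arm_compute;

bool kernel_indices_supported(const ITensorInfo *src, const ITensorInfo *dst, const ITensorInfo *indices)
{
    const PoolingLayerInfo pool_info(PoolingType::MAX, Size2D(3, 3), DataLayout::NHWC,
                                     PadStrideInfo(2, 2, 0, 0),
                                     false, false, true, true /* use_kernel_indices */);
    // Per the hunk above, this yields an error for NCHW sources when
    // use_kernel_indices is set, and for non-2x2 pools when it is not.
    const Status status = NEPoolingLayer::validate(src, dst, pool_info, indices);
    return static_cast<bool>(status);
}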
diff --git a/src/cpu/kernels/pool2d/neon/fp32.cpp b/src/cpu/kernels/pool2d/neon/fp32.cpp
index 018f62b8a8..8e93df3347 100644
--- a/src/cpu/kernels/pool2d/neon/fp32.cpp
+++ b/src/cpu/kernels/pool2d/neon/fp32.cpp
@@ -24,7 +24,6 @@
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/ITensor.h"
 #include "arm_compute/core/Types.h"
-#include "arm_compute/core/utils/misc/Traits.h"
 #include "src/core/NEON/wrapper/intrinsics/intrinsics.h"
 #include "src/core/helpers/WindowHelpers.h"
 #include "src/cpu/kernels/pool2d/neon/list.h"
@@ -98,10 +97,10 @@ void pooling2_f32_maxpool_indices(const ITensor *src, ITensor *dst0, ITensor *ds
                 vst1q_f32(reinterpret_cast<float *>(out.ptr()) + x_off, vres);
 
                 const uint32_t offset_base = offset_no_padding<float>(in.offset(), id, *src->info(), pool_stride_x, pool_stride_y, DataLayout::NHWC);
-                const uint32_t offset_x0   = (uint32_t)offset_base / sizeof(float) + x_off;
-                const uint32_t offset_x1   = (uint32_t)offset_x0 + in_stride_y / sizeof(float) - pad_horizontal;
-                const uint32_t offset_x2   = (uint32_t)offset_x0 + in_stride_z / sizeof(float) - pad_horizontal * src->info()->tensor_shape()[1];
-                const uint32_t offset_x3   = (uint32_t)offset_x2 + in_stride_y / sizeof(float) - pad_horizontal;
+                const uint32_t offset_x0   = offset_base / sizeof(float) + x_off;
+                const uint32_t offset_x1   = offset_x0 + in_stride_y / sizeof(float) - pad_horizontal;
+                const uint32_t offset_x2   = offset_x0 + in_stride_z / sizeof(float) - pad_horizontal * src->info()->tensor_shape()[1];
+                const uint32_t offset_x3   = offset_x2 + in_stride_y / sizeof(float) - pad_horizontal;
                 const uint32x4_t voffset_x0 = { offset_x0, offset_x0 + 1, offset_x0 + 2, offset_x0 + 3 };
                 const uint32x4_t voffset_x1 = { offset_x1, offset_x1 + 1, offset_x1 + 2, offset_x1 + 3 };
                 const uint32x4_t voffset_x2 = { offset_x2, offset_x2 + 1, offset_x2 + 2, offset_x2 + 3 };
@@ -127,10 +126,10 @@ void pooling2_f32_maxpool_indices(const ITensor *src, ITensor *dst0, ITensor *ds
                 *(reinterpret_cast<float *>(out.ptr()) + x_off) = res;
 
                 const uint32_t offset_base = offset_no_padding<float>(in.offset(), id, *src->info(), pool_stride_x, pool_stride_y, DataLayout::NHWC);
-                const uint32_t offset_x0   = (uint32_t)offset_base / sizeof(float) + x_off;
-                const uint32_t offset_x1   = (uint32_t)offset_x0 + in_stride_y / sizeof(float) - pad_horizontal;
-                const uint32_t offset_x2   = (uint32_t)offset_x0 + in_stride_z / sizeof(float) - pad_horizontal * src->info()->tensor_shape()[1];
-                const uint32_t offset_x3   = (uint32_t)offset_x2 + in_stride_y / sizeof(float) - pad_horizontal;
+                const uint32_t offset_x0   = offset_base / sizeof(float) + x_off;
+                const uint32_t offset_x1   = offset_x0 + in_stride_y / sizeof(float) - pad_horizontal;
+                const uint32_t offset_x2   = offset_x0 + in_stride_z / sizeof(float) - pad_horizontal * src->info()->tensor_shape()[1];
+                const uint32_t offset_x3   = offset_x2 + in_stride_y / sizeof(float) - pad_horizontal;
                 const uint32_t tmp_idx0    = (x0 >= x1) ? offset_x0 : offset_x1;
                 const uint32_t tmp_idx1    = (x2 >= x3) ? offset_x2 : offset_x3;
                 const uint32_t tmp_idx2    = (std::max(x0, x1) >= std::max(x2, x3)) ? tmp_idx0 : tmp_idx1;
@@ -141,11 +140,135 @@ void pooling2_f32_maxpool_indices(const ITensor *src, ITensor *dst0, ITensor *ds
     },
     in, out, indices);
 }
+} // namespace
+
+void poolingMxN_fp32_neon_nhwc_kernel_indices(const ITensor *src, ITensor *dst0, ITensor *dst1, const PoolingLayerInfo &pool_info, const Window &window)
+{
+    const int     window_start_x = window.x().start();
+    const int     window_end_x   = window.x().end();
+    constexpr int window_step_x  = 4;
+
+    Window window_out = window;
+    window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+    Iterator out(dst0, window_out);
+    Iterator indices(dst1, window_out);
+
+    const int pool_size_x = pool_info.is_global_pooling ? src->info()->tensor_shape().y() : pool_info.pool_size.width;
+    const int pool_size_y = pool_info.is_global_pooling ? src->info()->tensor_shape().z() : pool_info.pool_size.height;
+
+    const int pool_pad_top  = pool_info.pad_stride_info.pad_top();
+    const int pool_pad_left = pool_info.pad_stride_info.pad_left();
+
+    int pool_stride_x = 0;
+    int pool_stride_y = 0;
+    std::tie(pool_stride_x, pool_stride_y) = pool_info.pad_stride_info.stride();
+
+    const float min_value = get_initial_min<float>(pool_info.use_inf_as_limit);
+
+    float32x4_t vres;
+    uint32x4_t  vidx;
+
+    constexpr int idx_width  = 1;
+    constexpr int idx_height = 2;
+    constexpr int idx_batch  = 3;
+
+    const int y_stride = static_cast<int>(src->info()->strides_in_bytes().y());
+    const int z_stride = static_cast<int>(src->info()->strides_in_bytes().z());
+    const int n_stride = static_cast<int>(src->info()->strides_in_bytes()[idx_batch]);
+
+    const int input_dim_w = src->info()->dimension(idx_width);
+    const int input_dim_h = src->info()->dimension(idx_height);
+
+    const uint8_t *in_ptr_start = src->buffer() + src->info()->offset_first_element_in_bytes();
+
+    execute_window_loop(window_out, [&](const Coordinates & id)
+    {
+        const int idx_width  = static_cast<int>(id.y()) * pool_stride_x - pool_pad_left;
+        const int idx_height = static_cast<int>(id.z()) * pool_stride_y - pool_pad_top;
+
+        const int pool_start_x = std::max(0, -idx_width);
+        const int pool_start_y = std::max(0, -idx_height);
+
+        const int pool_end_x = std::min(pool_size_x, input_dim_w - idx_width);
+        const int pool_end_y = std::min(pool_size_y, input_dim_h - idx_height);
+
+        const uint8_t *in_ptr_n = in_ptr_start + id[idx_batch] * n_stride;
+
+        const int in_ptr_y_offset = (z_stride * idx_height) + (pool_start_y * z_stride);
+        const int in_ptr_x_offset = (y_stride * idx_width) + (pool_start_x * y_stride);
+
+        int x_off = window_start_x;
+
+        for(; x_off <= (window_end_x - window_step_x); x_off += window_step_x)
+        {
+            vres = vdupq_n_f32(min_value);
+            vidx = vdupq_n_u32(0U);
+            const uint8_t *in_ptr_y          = in_ptr_n + in_ptr_y_offset + in_ptr_x_offset;
+            uint32_t       curr_kernel_index = pool_size_x * pool_start_y;
+            for(int y = pool_start_y; y < pool_end_y; ++y)
+            {
+                const uint8_t *in_ptr_x = in_ptr_y + (x_off * sizeof(float));
+                curr_kernel_index += pool_start_x;
+                for(int x = pool_start_x; x < pool_end_x; ++x)
+                {
+                    const float32x4_t data      = vld1q_f32(reinterpret_cast<const float *>(in_ptr_x));
+                    const uint32x4_t  vidx_curr = vdupq_n_u32(curr_kernel_index);
+                    const uint32x4_t  idxMask   = vcgtq_f32(data, vres);
+                    vidx                        = vbslq_u32(idxMask, vidx_curr, vidx);
+                    vres                        = vmaxq_f32(vres, data);
+                    in_ptr_x += y_stride;
+                    curr_kernel_index++;
+                }
+                curr_kernel_index += (pool_size_x - pool_end_x);
+                in_ptr_y += z_stride;
+            }
+            // Store result
+            vst1q_f32(reinterpret_cast<float *>(out.ptr()) + x_off, vres);
+            vst1q_u32(reinterpret_cast<uint32_t *>(indices.ptr()) + x_off, vidx);
+        }
+
+        // Left-overs loop
+        for(; x_off < window_end_x; ++x_off)
+        {
+            float          res               = min_value;
+            uint32_t       idx               = 0U;
+            const uint8_t *in_ptr_y          = in_ptr_n + in_ptr_y_offset + in_ptr_x_offset;
+            uint32_t       curr_kernel_index = pool_size_x * pool_start_y;
+            for(int y = pool_start_y; y < pool_end_y; ++y)
+            {
+                const uint8_t *in_ptr_x = in_ptr_y + (x_off * sizeof(float));
+                curr_kernel_index += pool_start_x;
+                for(int x = pool_start_x; x < pool_end_x; ++x)
+                {
+                    const float data = *(reinterpret_cast<const float *>(in_ptr_x));
+                    if(data > res)
+                    {
+                        idx = pool_size_x * y + x;
+                        res = data;
+                    }
+                    in_ptr_x += y_stride;
+                    curr_kernel_index++;
+                }
+                curr_kernel_index += (pool_size_x - pool_end_x);
+                in_ptr_y += z_stride;
+            }
+
+            // Store result
+            *(reinterpret_cast<float *>(out.ptr()) + x_off)         = res;
+            *(reinterpret_cast<uint32_t *>(indices.ptr()) + x_off)  = idx;
+        }
+    },
+    out, indices);
+}
 
 void poolingMxN_fp32_neon_nhwc(const ITensor *src, ITensor *dst0, ITensor *dst1, PoolingLayerInfo &pool_info, const Window &window_src, const Window &window)
 {
-    if(pool_info.pool_size == Size2D(2, 2) && pool_info.pool_type == PoolingType::MAX && dst1)
+    if((pool_info.pool_type == PoolingType::MAX) && pool_info.use_kernel_indices && (dst1 != nullptr))
+    {
+        poolingMxN_fp32_neon_nhwc_kernel_indices(src, dst0, dst1, pool_info, window);
+    }
+    else if(pool_info.pool_size == Size2D(2, 2) && pool_info.pool_type == PoolingType::MAX && !pool_info.pad_stride_info.has_padding() && (dst1 != nullptr))
     {
         pooling2_f32_maxpool_indices(src, dst0, dst1, pool_info, window_src, window);
     }
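For review, a scalar model of what the new poolingMxN_fp32_neon_nhwc_kernel_indices kernel computes per output element — standalone C++, not patch code, using a planar single-channel layout for clarity (the real kernel walks NHWC data and vectorizes four channels at a time). The stored index is row-major within the window, and padded taps are skipped by clamping the window bounds, exactly as pool_start_x/pool_start_y do above:

#include <algorithm>
#include <cstdint>
#include <limits>
#include <utility>

// Scalar model for a single (out_x, out_y) position of one channel.
std::pair<float, uint32_t> max_pool_kernel_index(const float *in, int in_w, int in_h,
                                                 int pool_w, int pool_h,
                                                 int stride_x, int stride_y,
                                                 int pad_left, int pad_top,
                                                 int out_x, int out_y)
{
    const int in_x0 = out_x * stride_x - pad_left;
    const int in_y0 = out_y * stride_y - pad_top;

    // Clamp the window so padded taps are never read.
    const int kx0 = std::max(0, -in_x0);
    const int ky0 = std::max(0, -in_y0);
    const int kx1 = std::min(pool_w, in_w - in_x0);
    const int ky1 = std::min(pool_h, in_h - in_y0);

    float    best = -std::numeric_limits<float>::infinity();
    uint32_t idx  = 0;
    for(int ky = ky0; ky < ky1; ++ky)
    {
        for(int kx = kx0; kx < kx1; ++kx)
        {
            const float v = in[(in_y0 + ky) * in_w + (in_x0 + kx)];
            if(v > best)
            {
                best = v;
                idx  = static_cast<uint32_t>(ky * pool_w + kx); // row-major window index
            }
        }
    }
    return { best, idx };
}

Ties go to the earlier tap in both versions: the NEON path only replaces the index where data > vres holds strictly (vcgtq_f32), and the scalar comparison here is strict as well.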
diff --git a/src/cpu/operators/CpuPool2d.cpp b/src/cpu/operators/CpuPool2d.cpp
index eabbd5e0cc..722cd36ee5 100644
--- a/src/cpu/operators/CpuPool2d.cpp
+++ b/src/cpu/operators/CpuPool2d.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -27,7 +27,6 @@
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
 #include "src/common/utils/Log.h"
-#include "src/core/NEON/kernels/NEFillBorderKernel.h"
 #include "src/cpu/kernels/CpuPool2dKernel.h"
 #include "src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.h"
 
@@ -41,6 +40,7 @@
 CpuPool2d::CpuPool2d()
     : _pooling_layer_kernel(),
       _asm_glue(),
       _is_global_pooling_layer(false),
+      _use_kernel_indices(false),
       _data_layout(DataLayout::NCHW),
       _aux_mem(1)
 {
@@ -62,6 +62,7 @@ void CpuPool2d::configure(ITensorInfo *src, ITensorInfo *dst, const PoolingLayer
         const unsigned int idx_width  = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH);
         const unsigned int idx_height = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
         _is_global_pooling_layer      = (src->dimension(idx_width) == pool_info.pool_size.width) && (src->dimension(idx_height) == pool_info.pool_size.height);
+        _use_kernel_indices           = pool_info.use_kernel_indices;
 
         if(run_optimised)
         {
@@ -117,7 +118,7 @@ void CpuPool2d::run(ITensorPack &tensors)
             NEScheduler::get().schedule_op(_pooling_layer_kernel.get(), _is_global_pooling_layer ? Window::DimZ : Window::DimY, _pooling_layer_kernel->window(), tensors);
             break;
         case DataLayout::NHWC:
-            NEScheduler::get().schedule_op(_pooling_layer_kernel.get(), Window::DimX, _pooling_layer_kernel->window(), tensors);
+            NEScheduler::get().schedule_op(_pooling_layer_kernel.get(), (_use_kernel_indices ? Window::DimY : Window::DimX), _pooling_layer_kernel->window(), tensors);
             break;
         default:
             ARM_COMPUTE_ERROR("Data layout not supported");
diff --git a/src/cpu/operators/CpuPool2d.h b/src/cpu/operators/CpuPool2d.h
index 02c2609a6a..5c571db88a 100644
--- a/src/cpu/operators/CpuPool2d.h
+++ b/src/cpu/operators/CpuPool2d.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -76,6 +76,7 @@ private:
     std::unique_ptr<INEKernel> _asm_glue;
 
     bool                             _is_global_pooling_layer;
+    bool                             _use_kernel_indices;
     DataLayout                       _data_layout;
     experimental::MemoryRequirements _aux_mem{};
 };
idx; + } + }, + out, indices); } void poolingMxN_fp32_neon_nhwc(const ITensor *src, ITensor *dst0, ITensor *dst1, PoolingLayerInfo &pool_info, const Window &window_src, const Window &window) { - if(pool_info.pool_size == Size2D(2, 2) && pool_info.pool_type == PoolingType::MAX && dst1) + if((pool_info.pool_type == PoolingType::MAX) && pool_info.use_kernel_indices && (dst1 != nullptr)) + { + poolingMxN_fp32_neon_nhwc_kernel_indices(src, dst0, dst1, pool_info, window); + } + else if(pool_info.pool_size == Size2D(2, 2) && pool_info.pool_type == PoolingType::MAX && !pool_info.pad_stride_info.has_padding() && (dst1 != nullptr)) { pooling2_f32_maxpool_indices(src, dst0, dst1, pool_info, window_src, window); } diff --git a/src/cpu/operators/CpuPool2d.cpp b/src/cpu/operators/CpuPool2d.cpp index eabbd5e0cc..722cd36ee5 100644 --- a/src/cpu/operators/CpuPool2d.cpp +++ b/src/cpu/operators/CpuPool2d.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -27,7 +27,6 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/runtime/NEON/NEScheduler.h" #include "src/common/utils/Log.h" -#include "src/core/NEON/kernels/NEFillBorderKernel.h" #include "src/cpu/kernels/CpuPool2dKernel.h" #include "src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.h" @@ -41,6 +40,7 @@ CpuPool2d::CpuPool2d() : _pooling_layer_kernel(), _asm_glue(), _is_global_pooling_layer(false), + _use_kernel_indices(false), _data_layout(DataLayout::NCHW), _aux_mem(1) { @@ -62,6 +62,7 @@ void CpuPool2d::configure(ITensorInfo *src, ITensorInfo *dst, const PoolingLayer const unsigned int idx_width = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH); const unsigned int idx_height = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT); _is_global_pooling_layer = (src->dimension(idx_width) == pool_info.pool_size.width) && (src->dimension(idx_height) == pool_info.pool_size.height); + _use_kernel_indices = pool_info.use_kernel_indices; if(run_optimised) { @@ -117,7 +118,7 @@ void CpuPool2d::run(ITensorPack &tensors) NEScheduler::get().schedule_op(_pooling_layer_kernel.get(), _is_global_pooling_layer ? Window::DimZ : Window::DimY, _pooling_layer_kernel->window(), tensors); break; case DataLayout::NHWC: - NEScheduler::get().schedule_op(_pooling_layer_kernel.get(), Window::DimX, _pooling_layer_kernel->window(), tensors); + NEScheduler::get().schedule_op(_pooling_layer_kernel.get(), (_use_kernel_indices ? Window::DimY : Window::DimX), _pooling_layer_kernel->window(), tensors); break; default: ARM_COMPUTE_ERROR("Data layout not supported"); diff --git a/src/cpu/operators/CpuPool2d.h b/src/cpu/operators/CpuPool2d.h index 02c2609a6a..5c571db88a 100644 --- a/src/cpu/operators/CpuPool2d.h +++ b/src/cpu/operators/CpuPool2d.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -76,6 +76,7 @@ private: std::unique_ptr _asm_glue; bool _is_global_pooling_layer; + bool _use_kernel_indices; DataLayout _data_layout; experimental::MemoryRequirements _aux_mem{}; }; diff --git a/tests/validation/CL/PoolingLayer.cpp b/tests/validation/CL/PoolingLayer.cpp index f17021671c..9fe28c7acf 100644 --- a/tests/validation/CL/PoolingLayer.cpp +++ b/tests/validation/CL/PoolingLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2023 Arm Limited. 
  *
  * SPDX-License-Identifier: MIT
  *
@@ -181,11 +181,11 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLPoolingLayerFixture<float>, framework::Datase
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
 
-FIXTURE_DATA_TEST_CASE(RunSmallIndices, CLPoolingLayerIndicesFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallNoneUnitShapes(),
+FIXTURE_DATA_TEST_CASE(RunSmallIndices, CLPoolingLayerIndicesFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallNoneUnitShapes(),
                                                                                                                         combine(PoolingLayerDatasetFPIndicesSmall,
                                                                                                                                 framework::dataset::make("DataType",
                                                                                                                                         DataType::F32))),
-                                                                                                                        pool_data_layout_dataset))
+                                                                                                                        pool_data_layout_dataset),framework::dataset::make("UseKernelIndices", { false })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -250,11 +250,11 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLMixedPrecesionPoolingLayerFixture<half>, fram
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunSmallIndices, CLPoolingLayerIndicesFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallNoneUnitShapes(),
+FIXTURE_DATA_TEST_CASE(RunSmallIndices, CLPoolingLayerIndicesFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallNoneUnitShapes(),
                                                                                                                        combine(PoolingLayerDatasetFPIndicesSmall,
                                                                                                                                framework::dataset::make("DataType",
                                                                                                                                        DataType::F16))),
-                                                                                                                       pool_data_layout_dataset))
+                                                                                                                       pool_data_layout_dataset), framework::dataset::make("UseKernelIndices", { false })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
diff --git a/tests/validation/NEON/PoolingLayer.cpp b/tests/validation/NEON/PoolingLayer.cpp
index 457610f2bd..3acd453ea2 100644
--- a/tests/validation/NEON/PoolingLayer.cpp
+++ b/tests/validation/NEON/PoolingLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,7 +24,6 @@
 #include "arm_compute/core/Types.h"
 #include "arm_compute/runtime/NEON/functions/NEPoolingLayer.h"
 #include "arm_compute/runtime/Tensor.h"
-#include "arm_compute/runtime/TensorAllocator.h"
 #include "tests/NEON/Accessor.h"
 #include "tests/PaddingCalculator.h"
 #include "tests/datasets/PoolingLayerDataset.h"
@@ -150,13 +149,26 @@ using NESpecialPoolingLayerFixture = SpecialPoolingLayerValidationFixture
 
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunIndices, NEPoolingLayerIndicesFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallNoneUnitShapes(), combine(PoolingLayerIndicesDatasetFPSmall,
-                                                                                                           framework::dataset::make("DataType",
-                                                                                                                   DataType::F32))),
-                                                                                                           framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+FIXTURE_DATA_TEST_CASE(RunIndices, NEPoolingLayerIndicesFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallNoneUnitShapes(),
+                                                                                                                   combine(PoolingLayerIndicesDatasetFPSmall,
+                                                                                                                           framework::dataset::make("DataType", DataType::F32))),
+                                                                                                                   framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+                                                                                                                   framework::dataset::make("UseKernelIndices", { false })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference, tolerance_f32);
+    validate(Accessor(_target_indices), _ref_indices);
+}
+FIXTURE_DATA_TEST_CASE(RunKernelIndices, NEPoolingLayerIndicesFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallNoneUnitShapes(),
+                                                                                                                   combine(PoolingLayerKernelIndicesDatasetFPSmall,
+                                                                                                                           framework::dataset::make("DataType", DataType::F32))),
+                                                                                                                   framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+                                                                                                                   framework::dataset::make("UseKernelIndices", { true })))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f32);
@@ -208,10 +220,12 @@ TEST_SUITE_END() // FP32
 
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunIndices, NEPoolingLayerIndicesFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallNoneUnitShapes(), combine(PoolingLayerIndicesDatasetFPSmall,
-                                                                                                          framework::dataset::make("DataType",
-                                                                                                                  DataType::F16))),
-                                                                                                          framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+FIXTURE_DATA_TEST_CASE(RunIndices, NEPoolingLayerIndicesFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallNoneUnitShapes(),
+                                                                                                                  combine(PoolingLayerIndicesDatasetFPSmall,
+                                                                                                                          framework::dataset::make("DataType",
+                                                                                                                                  DataType::F16))),
+                                                                                                                  framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+                                                                                                                  framework::dataset::make("UseKernelIndices", { false })))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f16);
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h
index 6e9edfbb5d..f34aaa8bfa 100644
--- a/tests/validation/fixtures/PoolingLayerFixture.h
+++ b/tests/validation/fixtures/PoolingLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2023 Arm Limited.
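The reference-model hunk below tracks both encodings and selects at store time via info.use_kernel_indices. A worked example of the difference, with made-up numbers, as a compilable fragment:

// Illustrative arithmetic only, not patch code: the same maximum under both
// encodings. Window offset (kw, kh) = (2, 1) inside a 3x3 pool anchored at
// source (x, y) = (4, 6); source width 16, NCHW, batch 0, channel 0.
constexpr int pool_w = 3, src_w = 16;
constexpr int kw = 2, kh = 1, x = 4, y = 6;

constexpr int kernel_index = pool_w * kh + kw;            // 5: stored when use_kernel_indices is true
constexpr int source_index = (y + kh) * src_w + (x + kw); // 118: what coord2index() yields otherwise
static_assert(kernel_index == 5 && source_index == 118, "worked example");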
  *
  * SPDX-License-Identifier: MIT
  *
@@ -162,9 +162,10 @@ class PoolingLayerIndicesValidationFixture : public PoolingLayerValidationGeneri
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout)
+    void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout, bool use_kernel_indices)
     {
-        PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding),
+        PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding, false,
+                                                                                                                       true, use_kernel_indices),
                                                                                                data_type, data_layout, true);
     }
 };
diff --git a/tests/validation/reference/PoolingLayer.cpp b/tests/validation/reference/PoolingLayer.cpp
index 6a358ced0c..bf7bd0c1df 100644
--- a/tests/validation/reference/PoolingLayer.cpp
+++ b/tests/validation/reference/PoolingLayer.cpp
@@ -83,20 +83,28 @@ SimpleTensor<T> pooling_layer_internal(const SimpleTensor<T> &src, const Pooling
                 {
                     int wstart = w * pool_stride_x - pad_left;
                     int hstart = h * pool_stride_y - pad_top;
+
+                    // Used to calculate kernel indices
+                    int kh_start = std::max(0, -hstart);
+                    int kw_start = std::max(0, -wstart);
+                    int max_ker_index{ 0 };
+
                     int wend = std::min(wstart + pool_size_x, w_src);
                     int hend = std::min(hstart + pool_size_y, h_src);
                     wstart   = std::max(wstart, 0);
                     hstart   = std::max(hstart, 0);
                     auto max_val = info.use_inf_as_limit ? -std::numeric_limits<ACC_T>::infinity() : std::numeric_limits<ACC_T>::lowest();
                     int  max_index{ 0 };
-                    for(int y = hstart; y < hend; ++y)
+
+                    for(int y = hstart, kh = kh_start; y < hend; ++y, ++kh)
                     {
-                        for(int x = wstart; x < wend; ++x)
+                        for(int x = wstart, kw = kw_start; x < wend; ++x, ++kw)
                         {
                             const auto val = static_cast<ACC_T>(src[b * z_src * h_src * w_src + r * h_src * w_src + y * w_src + x]);
                             if(val > max_val)
                             {
-                                max_val = val;
+                                max_val       = val;
+                                max_ker_index = pool_size_x * (kh) + (kw);
                                 if(data_layout == DataLayout::NCHW)
                                 {
                                     max_index = coord2index(src.shape(), Coordinates(x, y, r, 0));
@@ -112,7 +120,7 @@ SimpleTensor<T> pooling_layer_internal(const SimpleTensor<T> &src, const Pooling
                     dst[b * z_dst * h_dst * w_dst + r * h_dst * w_dst + h * w_dst + w] = static_cast<T>(max_val);
                     if(indices)
                     {
-                        (*indices)[b * z_dst * h_dst * w_dst + r * h_dst * w_dst + h * w_dst + w] = max_index;
+                        (*indices)[b * z_dst * h_dst * w_dst + r * h_dst * w_dst + h * w_dst + w] = (info.use_kernel_indices) ? max_ker_index : max_index;
                     }
                 }
             }
-- 
cgit v1.2.1
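One practical footnote, again not from the patch: a kernel index is only meaningful together with its output coordinate, but it is cheap to map back to source coordinates when needed (e.g. for max-unpooling). A sketch under that assumption, names ours:

#include <cstdint>

// Recover the source coordinate picked by a max-pool output element from its
// kernel index; out-of-bounds results cannot occur because padded taps are
// never selected by the kernel.
struct SrcCoord
{
    int x, y;
};

inline SrcCoord unravel_kernel_index(uint32_t k, int pool_w,
                                     int out_x, int out_y,
                                     int stride_x, int stride_y,
                                     int pad_left, int pad_top)
{
    const int kx = static_cast<int>(k) % pool_w; // column inside the window
    const int ky = static_cast<int>(k) / pool_w; // row inside the window
    return { out_x * stride_x - pad_left + kx, out_y * stride_y - pad_top + ky };
}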