diff options
Diffstat (limited to 'src/cpu/kernels/CpuPool2dKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuPool2dKernel.cpp | 89 |
1 files changed, 30 insertions, 59 deletions
diff --git a/src/cpu/kernels/CpuPool2dKernel.cpp b/src/cpu/kernels/CpuPool2dKernel.cpp index f61cd0835d..953a9ffb67 100644 --- a/src/cpu/kernels/CpuPool2dKernel.cpp +++ b/src/cpu/kernels/CpuPool2dKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -52,136 +52,101 @@ namespace { using namespace misc::shape_calculator; -struct PoolingSelectorData -{ - DataType dt; - DataLayout dl; - int pool_stride_x; - Size2D pool_size; -}; - -using PoolingSelectorPtr = std::add_pointer<bool(const PoolingSelectorData &data)>::type; -using PoolingKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, ITensor *, PoolingLayerInfo &, const Window &, const Window &)>::type; -struct PoolingKernel -{ - const char *name; - const PoolingSelectorPtr is_selected; - PoolingKernelPtr ukernel; -}; - -static const PoolingKernel available_kernels[] = +static const std::vector<CpuPool2dKernel::PoolingKernel> available_kernels = { { "neon_qu8_nhwc_poolMxN", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NHWC) && (data.dt == DataType::QASYMM8)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NHWC) && (data.dt == DataType::QASYMM8)); }, REGISTER_QASYMM8_NEON(arm_compute::cpu::poolingMxN_qasymm8_neon_nhwc) }, { "neon_qs8_nhwc_poolMxN", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NHWC) && (data.dt == DataType::QASYMM8_SIGNED)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NHWC) && (data.dt == DataType::QASYMM8_SIGNED)); }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::poolingMxN_qasymm8_signed_neon_nhwc) }, #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) { "neon_f16_nhwc_poolMxN", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NHWC) && (data.dt == DataType::F16)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NHWC) && (data.dt == DataType::F16)); }, REGISTER_FP16_NEON(arm_compute::cpu::poolingMxN_fp16_neon_nhwc) }, #endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */ { "neon_fp32_nhwc_poolMxN", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NHWC) && (data.dt == DataType::F32)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NHWC) && (data.dt == DataType::F32)); }, REGISTER_FP32_NEON(arm_compute::cpu::poolingMxN_fp32_neon_nhwc) }, #if defined(ENABLE_NCHW_KERNELS) { "neon_qu8_nchw_pool2", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::QASYMM8) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 2) && (data.pool_stride_x < 3)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::QASYMM8) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 2) && (data.pool_stride_x < 3)); }, REGISTER_QASYMM8_NEON(arm_compute::cpu::pooling2_quantized_neon_nchw<uint8_t>) }, { "neon_qu8_nchw_pool3", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::QASYMM8) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 3) && (data.pool_stride_x < 3)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::QASYMM8) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 3) && (data.pool_stride_x < 3)); }, REGISTER_QASYMM8_NEON(arm_compute::cpu::pooling3_quantized_neon_nchw<uint8_t>) }, { "neon_qu8_nchw_poolMxN", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::QASYMM8)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::QASYMM8)); }, REGISTER_QASYMM8_NEON(arm_compute::cpu::poolingMxN_quantized_neon_nchw<uint8_t>) }, { "neon_qs8_nchw_pool2", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::QASYMM8_SIGNED) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 2) && (data.pool_stride_x < 3)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::QASYMM8_SIGNED) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 2) && (data.pool_stride_x < 3)); }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::pooling2_quantized_neon_nchw<int8_t>) }, { "neon_qs8_nchw_pool3", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::QASYMM8_SIGNED) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 3) && (data.pool_stride_x < 3)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::QASYMM8_SIGNED) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 3) && (data.pool_stride_x < 3)); }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::pooling3_quantized_neon_nchw<int8_t>) }, { "neon_qs8_nchw_poolMxN", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::QASYMM8_SIGNED)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::QASYMM8_SIGNED)); }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::poolingMxN_quantized_neon_nchw<int8_t>) }, #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) { "neon_fp16_nchw_pool2", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F16) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 2)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F16 && data.isa.fp16) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 2)); }, REGISTER_FP16_NEON(arm_compute::cpu::pooling2_fp16_neon_nchw) }, { "neon_fp16_nchw_pool3", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F16) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 3)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F16 && data.isa.fp16) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 3)); }, REGISTER_FP16_NEON(arm_compute::cpu::pooling3_fp16_neon_nchw) }, { "neon_fp16_nchw_poolMxN", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F16)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F16 && data.isa.fp16)); }, REGISTER_FP16_NEON(arm_compute::cpu::poolingMxN_fp16_neon_nchw) }, #endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */ { "neon_fp32_nchw_pool2", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F32) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 2)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F32) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 2)); }, REGISTER_FP32_NEON(arm_compute::cpu::pooling2_fp32_neon_nchw) }, { "neon_fp32_nchw_pool3", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F32) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 3)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F32) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 3)); }, REGISTER_FP32_NEON(arm_compute::cpu::pooling3_fp32_neon_nchw) }, { "neon_fp32_nchw_pool7", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F32) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 7)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F32) && (data.pool_size.x() == data.pool_size.y()) && (data.pool_size.x() == 7)); }, REGISTER_FP32_NEON(arm_compute::cpu::pooling7_fp32_neon_nchw) }, { "neon_fp32_nchw_poolMxN", - [](const PoolingSelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F32)); }, + [](const PoolDataTypeISASelectorData & data) { return ((data.dl == DataLayout::NCHW) && (data.dt == DataType::F32)); }, REGISTER_FP32_NEON(arm_compute::cpu::poolingMxN_fp32_neon_nchw) }, #endif /* defined(ENABLE_NCHW_KERNELS) */ }; -/** Micro-kernel selector - * - * @param[in] data Selection data passed to help pick the appropriate micro-kernel - * - * @return A matching micro-kernel else nullptr - */ -const PoolingKernel *get_implementation(DataType dt, DataLayout dl, int pool_stride_x, Size2D pool_size) -{ - for(const auto &uk : available_kernels) - { - if(uk.is_selected({ dt, dl, pool_stride_x, pool_size })) - { - return &uk; - } - } - return nullptr; -} - Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const PoolingLayerInfo &pool_info, const ITensorInfo *indices, Size2D pool_size) { @@ -235,7 +200,7 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const } } - const auto *uk = get_implementation(src->data_type(), src->data_layout(), pool_stride_x, pool_size); + const auto *uk = CpuPool2dKernel::get_implementation(PoolDataTypeISASelectorData{ src->data_type(), src->data_layout(), pool_stride_x, pool_size, CPUInfo::get().get_isa() }); ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); return Status{}; @@ -335,7 +300,7 @@ void CpuPool2dKernel::configure(ITensorInfo *src, ITensorInfo *dst, const Poolin // Perform validation step ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, dst, pool_info, indices, pool_size)); - const auto *uk = get_implementation(src->data_type(), src->data_layout(), pad_stride_info.stride().first, pool_size); + const auto *uk = CpuPool2dKernel::get_implementation(PoolDataTypeISASelectorData{ src->data_type(), src->data_layout(), (int)pad_stride_info.stride().first, pool_size, CPUInfo::get().get_isa() }); ARM_COMPUTE_ERROR_ON(uk == nullptr); // Set instance variables @@ -350,7 +315,7 @@ void CpuPool2dKernel::configure(ITensorInfo *src, ITensorInfo *dst, const Poolin { // Configure kernel window Window win = calculate_max_window(*dst, Steps()); - ICpuKernel::configure(win); + NewICpuKernel::configure(win); } else { @@ -358,7 +323,7 @@ void CpuPool2dKernel::configure(ITensorInfo *src, ITensorInfo *dst, const Poolin auto win_config = validate_and_configure_window(src, dst, indices, pool_info, _num_elems_processed_per_iteration, pool_size.x(), pool_size.y()); ARM_COMPUTE_ERROR_THROW_ON(win_config.first); - ICpuKernel::configure(win_config.second); + NewICpuKernel::configure(win_config.second); } } @@ -391,7 +356,7 @@ void CpuPool2dKernel::run_op(ITensorPack &tensors, const Window &window, const T { ARM_COMPUTE_UNUSED(info); ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); - ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window); + ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(NewICpuKernel::window(), window); ARM_COMPUTE_ERROR_ON(_run_method == nullptr); const ITensor *src = tensors.get_const_tensor(TensorType::ACL_SRC_0); @@ -447,6 +412,12 @@ const char *CpuPool2dKernel::name() const { return _name.c_str(); } + +const std::vector<CpuPool2dKernel::PoolingKernel> &CpuPool2dKernel::get_available_kernels() +{ + return available_kernels; +} + } // namespace kernels } // namespace cpu } // namespace arm_compute |