diff options
Diffstat (limited to 'src/core/NEON/kernels/NESelectKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NESelectKernel.cpp | 156 |
1 files changed, 63 insertions, 93 deletions
diff --git a/src/core/NEON/kernels/NESelectKernel.cpp b/src/core/NEON/kernels/NESelectKernel.cpp index b8c9b244ee..7789b828ea 100644 --- a/src/core/NEON/kernels/NESelectKernel.cpp +++ b/src/core/NEON/kernels/NESelectKernel.cpp @@ -29,13 +29,12 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" + +#include "src/core/common/Registrars.h" #include "src/core/CPP/Validate.h" -#include "src/core/NEON/wrapper/wrapper.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" - -#include "src/core/common/Registrars.h" - +#include "src/core/NEON/wrapper/wrapper.h" #include "src/cpu/kernels/select/list.h" #include <arm_neon.h> @@ -54,7 +53,8 @@ struct SelectKernelSelectorData }; using SelectorPtr = std::add_pointer<bool(const SelectKernelSelectorData &data)>::type; -using KernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, const ITensor *, ITensor *, const Window &)>::type; +using KernelPtr = + std::add_pointer<void(const ITensor *, const ITensor *, const ITensor *, ITensor *, const Window &)>::type; struct SelectKernelSelector { @@ -63,95 +63,62 @@ struct SelectKernelSelector KernelPtr ukernel; }; -static const SelectKernelSelector available_kernels[] = -{ - { - "neon_s8_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::S8 && data.is_same_rank == true; }, - REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s8_select_same_rank) - }, - { - "neon_s16_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::S16 && data.is_same_rank == true; }, - REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s16_select_same_rank) - }, - { - "neon_s32_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::S32 && data.is_same_rank == true; }, - REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s32_select_same_rank) - }, - { - "neon_u8_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::U8 && data.is_same_rank == true; }, - REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u8_select_same_rank) - }, - { - "neon_u16_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::U16 && data.is_same_rank == true; }, - REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u16_select_same_rank) - }, - { - "neon_u32_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::U32 && data.is_same_rank == true; }, - REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u32_select_same_rank) - }, - { - "neon_s8_not_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::S8 && data.is_same_rank == false; }, - REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s8_select_not_same_rank) - }, - { - "neon_s16_not_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::S16 && data.is_same_rank == false; }, - REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s16_select_not_same_rank) - }, - { - "neon_s32_not_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::S32 && data.is_same_rank == false; }, - REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s32_select_not_same_rank) - }, - { - "neon_u8_not_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::U8 && data.is_same_rank == false; }, - REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u8_select_not_same_rank) - }, - { - "neon_u16_not_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::U16 && data.is_same_rank == false; }, - REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u16_select_not_same_rank) - }, - { - "neon_u32_not_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::U32 && data.is_same_rank == false; }, - REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u32_select_not_same_rank) - }, - { - "neon_f16_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::F16 && data.is_same_rank == true; }, - REGISTER_FP16_NEON(arm_compute::cpu::neon_f16_select_same_rank) - }, - { - "neon_f16_not_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::F16 && data.is_same_rank == false; }, - REGISTER_FP16_NEON(arm_compute::cpu::neon_f16_select_not_same_rank) - }, - { - "neon_f32_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::F32 && data.is_same_rank == true; }, - REGISTER_FP32_NEON(arm_compute::cpu::neon_f32_select_same_rank) - }, - { - "neon_f32_not_same_rank", - [](const SelectKernelSelectorData & data) { return data.dt == DataType::F32 && data.is_same_rank == false; }, - REGISTER_FP32_NEON(arm_compute::cpu::neon_f32_select_not_same_rank) - }, +static const SelectKernelSelector available_kernels[] = { + {"neon_s8_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::S8 && data.is_same_rank == true; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s8_select_same_rank)}, + {"neon_s16_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::S16 && data.is_same_rank == true; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s16_select_same_rank)}, + {"neon_s32_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::S32 && data.is_same_rank == true; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s32_select_same_rank)}, + {"neon_u8_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::U8 && data.is_same_rank == true; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u8_select_same_rank)}, + {"neon_u16_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::U16 && data.is_same_rank == true; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u16_select_same_rank)}, + {"neon_u32_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::U32 && data.is_same_rank == true; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u32_select_same_rank)}, + {"neon_s8_not_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::S8 && data.is_same_rank == false; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s8_select_not_same_rank)}, + {"neon_s16_not_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::S16 && data.is_same_rank == false; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s16_select_not_same_rank)}, + {"neon_s32_not_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::S32 && data.is_same_rank == false; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s32_select_not_same_rank)}, + {"neon_u8_not_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::U8 && data.is_same_rank == false; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u8_select_not_same_rank)}, + {"neon_u16_not_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::U16 && data.is_same_rank == false; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u16_select_not_same_rank)}, + {"neon_u32_not_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::U32 && data.is_same_rank == false; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u32_select_not_same_rank)}, + {"neon_f16_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::F16 && data.is_same_rank == true; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_f16_select_same_rank)}, + {"neon_f16_not_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::F16 && data.is_same_rank == false; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_f16_select_not_same_rank)}, + {"neon_f32_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::F32 && data.is_same_rank == true; }, + REGISTER_FP32_NEON(arm_compute::cpu::neon_f32_select_same_rank)}, + {"neon_f32_not_same_rank", + [](const SelectKernelSelectorData &data) { return data.dt == DataType::F32 && data.is_same_rank == false; }, + REGISTER_FP32_NEON(arm_compute::cpu::neon_f32_select_not_same_rank)}, }; const SelectKernelSelector *get_implementation(const SelectKernelSelectorData &data) { - for(const auto &uk : available_kernels) + for (const auto &uk : available_kernels) { - if(uk.is_selected(data)) + if (uk.is_selected(data)) { return &uk; } @@ -184,7 +151,8 @@ void NESelectKernel::configure(const ITensor *c, const ITensor *x, const ITensor INEKernel::configure(win); } -Status NESelectKernel::validate(const ITensorInfo *c, const ITensorInfo *x, const ITensorInfo *y, const ITensorInfo *output) +Status +NESelectKernel::validate(const ITensorInfo *c, const ITensorInfo *x, const ITensorInfo *y, const ITensorInfo *output) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(c, x, y); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(x); @@ -195,9 +163,11 @@ Status NESelectKernel::validate(const ITensorInfo *c, const ITensorInfo *x, cons const bool is_same_rank = (c->tensor_shape().num_dimensions() == x->tensor_shape().num_dimensions()); ARM_COMPUTE_RETURN_ERROR_ON(is_same_rank && (x->tensor_shape() != c->tensor_shape())); - ARM_COMPUTE_RETURN_ERROR_ON(!is_same_rank && ((c->tensor_shape().num_dimensions() > 1) || (c->tensor_shape().x() != x->tensor_shape()[x->tensor_shape().num_dimensions() - 1]))); + ARM_COMPUTE_RETURN_ERROR_ON(!is_same_rank && + ((c->tensor_shape().num_dimensions() > 1) || + (c->tensor_shape().x() != x->tensor_shape()[x->tensor_shape().num_dimensions() - 1]))); - if(output != nullptr && output->total_size() != 0) + if (output != nullptr && output->total_size() != 0) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(x, output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(x, output); @@ -214,7 +184,7 @@ void NESelectKernel::run(const Window &window, const ThreadInfo &info) ARM_COMPUTE_ERROR_ON(_output == nullptr); ARM_COMPUTE_ERROR_ON(_output->info() == nullptr); - const auto *uk = get_implementation(SelectKernelSelectorData{ _output->info()->data_type(), _has_same_rank }); + const auto *uk = get_implementation(SelectKernelSelectorData{_output->info()->data_type(), _has_same_rank}); ARM_COMPUTE_ERROR_ON(uk == nullptr); ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); uk->ukernel(_c, _x, _y, _output, window); |