aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NESelectKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NESelectKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NESelectKernel.cpp 156
1 files changed, 63 insertions, 93 deletions
diff --git a/src/core/NEON/kernels/NESelectKernel.cpp b/src/core/NEON/kernels/NESelectKernel.cpp
index b8c9b244ee..7789b828ea 100644
--- a/src/core/NEON/kernels/NESelectKernel.cpp
+++ b/src/core/NEON/kernels/NESelectKernel.cpp
@@ -29,13 +29,12 @@
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
+
+#include "src/core/common/Registrars.h"
#include "src/core/CPP/Validate.h"
-#include "src/core/NEON/wrapper/wrapper.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
-
-#include "src/core/common/Registrars.h"
-
+#include "src/core/NEON/wrapper/wrapper.h"
#include "src/cpu/kernels/select/list.h"
#include <arm_neon.h>
@@ -54,7 +53,8 @@ struct SelectKernelSelectorData
};
using SelectorPtr = std::add_pointer<bool(const SelectKernelSelectorData &data)>::type;
-using KernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, const ITensor *, ITensor *, const Window &)>::type;
+using KernelPtr =
+ std::add_pointer<void(const ITensor *, const ITensor *, const ITensor *, ITensor *, const Window &)>::type;
struct SelectKernelSelector
{
@@ -63,95 +63,62 @@ struct SelectKernelSelector
KernelPtr ukernel;
};
-static const SelectKernelSelector available_kernels[] =
-{
- {
- "neon_s8_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::S8 && data.is_same_rank == true; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s8_select_same_rank)
- },
- {
- "neon_s16_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::S16 && data.is_same_rank == true; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s16_select_same_rank)
- },
- {
- "neon_s32_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::S32 && data.is_same_rank == true; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s32_select_same_rank)
- },
- {
- "neon_u8_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::U8 && data.is_same_rank == true; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u8_select_same_rank)
- },
- {
- "neon_u16_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::U16 && data.is_same_rank == true; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u16_select_same_rank)
- },
- {
- "neon_u32_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::U32 && data.is_same_rank == true; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u32_select_same_rank)
- },
- {
- "neon_s8_not_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::S8 && data.is_same_rank == false; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s8_select_not_same_rank)
- },
- {
- "neon_s16_not_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::S16 && data.is_same_rank == false; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s16_select_not_same_rank)
- },
- {
- "neon_s32_not_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::S32 && data.is_same_rank == false; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s32_select_not_same_rank)
- },
- {
- "neon_u8_not_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::U8 && data.is_same_rank == false; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u8_select_not_same_rank)
- },
- {
- "neon_u16_not_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::U16 && data.is_same_rank == false; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u16_select_not_same_rank)
- },
- {
- "neon_u32_not_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::U32 && data.is_same_rank == false; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u32_select_not_same_rank)
- },
- {
- "neon_f16_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::F16 && data.is_same_rank == true; },
- REGISTER_FP16_NEON(arm_compute::cpu::neon_f16_select_same_rank)
- },
- {
- "neon_f16_not_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::F16 && data.is_same_rank == false; },
- REGISTER_FP16_NEON(arm_compute::cpu::neon_f16_select_not_same_rank)
- },
- {
- "neon_f32_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::F32 && data.is_same_rank == true; },
- REGISTER_FP32_NEON(arm_compute::cpu::neon_f32_select_same_rank)
- },
- {
- "neon_f32_not_same_rank",
- [](const SelectKernelSelectorData & data) { return data.dt == DataType::F32 && data.is_same_rank == false; },
- REGISTER_FP32_NEON(arm_compute::cpu::neon_f32_select_not_same_rank)
- },
+static const SelectKernelSelector available_kernels[] = {
+ {"neon_s8_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::S8 && data.is_same_rank == true; },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s8_select_same_rank)},
+ {"neon_s16_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::S16 && data.is_same_rank == true; },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s16_select_same_rank)},
+ {"neon_s32_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::S32 && data.is_same_rank == true; },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s32_select_same_rank)},
+ {"neon_u8_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::U8 && data.is_same_rank == true; },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u8_select_same_rank)},
+ {"neon_u16_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::U16 && data.is_same_rank == true; },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u16_select_same_rank)},
+ {"neon_u32_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::U32 && data.is_same_rank == true; },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u32_select_same_rank)},
+ {"neon_s8_not_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::S8 && data.is_same_rank == false; },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s8_select_not_same_rank)},
+ {"neon_s16_not_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::S16 && data.is_same_rank == false; },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s16_select_not_same_rank)},
+ {"neon_s32_not_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::S32 && data.is_same_rank == false; },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s32_select_not_same_rank)},
+ {"neon_u8_not_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::U8 && data.is_same_rank == false; },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u8_select_not_same_rank)},
+ {"neon_u16_not_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::U16 && data.is_same_rank == false; },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u16_select_not_same_rank)},
+ {"neon_u32_not_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::U32 && data.is_same_rank == false; },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u32_select_not_same_rank)},
+ {"neon_f16_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::F16 && data.is_same_rank == true; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_f16_select_same_rank)},
+ {"neon_f16_not_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::F16 && data.is_same_rank == false; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_f16_select_not_same_rank)},
+ {"neon_f32_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::F32 && data.is_same_rank == true; },
+ REGISTER_FP32_NEON(arm_compute::cpu::neon_f32_select_same_rank)},
+ {"neon_f32_not_same_rank",
+ [](const SelectKernelSelectorData &data) { return data.dt == DataType::F32 && data.is_same_rank == false; },
+ REGISTER_FP32_NEON(arm_compute::cpu::neon_f32_select_not_same_rank)},
};
const SelectKernelSelector *get_implementation(const SelectKernelSelectorData &data)
{
- for(const auto &uk : available_kernels)
+ for (const auto &uk : available_kernels)
{
- if(uk.is_selected(data))
+ if (uk.is_selected(data))
{
return &uk;
}
@@ -184,7 +151,8 @@ void NESelectKernel::configure(const ITensor *c, const ITensor *x, const ITensor
INEKernel::configure(win);
}
-Status NESelectKernel::validate(const ITensorInfo *c, const ITensorInfo *x, const ITensorInfo *y, const ITensorInfo *output)
+Status
+NESelectKernel::validate(const ITensorInfo *c, const ITensorInfo *x, const ITensorInfo *y, const ITensorInfo *output)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(c, x, y);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(x);
@@ -195,9 +163,11 @@ Status NESelectKernel::validate(const ITensorInfo *c, const ITensorInfo *x, cons
const bool is_same_rank = (c->tensor_shape().num_dimensions() == x->tensor_shape().num_dimensions());
ARM_COMPUTE_RETURN_ERROR_ON(is_same_rank && (x->tensor_shape() != c->tensor_shape()));
- ARM_COMPUTE_RETURN_ERROR_ON(!is_same_rank && ((c->tensor_shape().num_dimensions() > 1) || (c->tensor_shape().x() != x->tensor_shape()[x->tensor_shape().num_dimensions() - 1])));
+ ARM_COMPUTE_RETURN_ERROR_ON(!is_same_rank &&
+ ((c->tensor_shape().num_dimensions() > 1) ||
+ (c->tensor_shape().x() != x->tensor_shape()[x->tensor_shape().num_dimensions() - 1])));
- if(output != nullptr && output->total_size() != 0)
+ if (output != nullptr && output->total_size() != 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(x, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(x, output);
@@ -214,7 +184,7 @@ void NESelectKernel::run(const Window &window, const ThreadInfo &info)
ARM_COMPUTE_ERROR_ON(_output == nullptr);
ARM_COMPUTE_ERROR_ON(_output->info() == nullptr);
- const auto *uk = get_implementation(SelectKernelSelectorData{ _output->info()->data_type(), _has_same_rank });
+ const auto *uk = get_implementation(SelectKernelSelectorData{_output->info()->data_type(), _has_same_rank});
ARM_COMPUTE_ERROR_ON(uk == nullptr);
ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
uk->ukernel(_c, _x, _y, _output, window);