From ff4fca0d2ae523557a7b31db2014b48391f1d8c3 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 2 Oct 2020 21:00:00 +0100 Subject: COMPMID-3684: Use case data type decoupling Decouples data types for NEFloorKernel Signed-off-by: Georgios Pinitas Change-Id: I6756300540bc5ef32a9990246eed8619a76855f2 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4084 Reviewed-by: Giorgio Arena Reviewed-by: Gian Marco Iodice Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- src/core/NEON/kernels/NEFloorKernel.cpp | 113 +++++++++++++++----------------- 1 file changed, 53 insertions(+), 60 deletions(-) (limited to 'src/core/NEON/kernels/NEFloorKernel.cpp') diff --git a/src/core/NEON/kernels/NEFloorKernel.cpp b/src/core/NEON/kernels/NEFloorKernel.cpp index e134097f7a..301dc7a422 100644 --- a/src/core/NEON/kernels/NEFloorKernel.cpp +++ b/src/core/NEON/kernels/NEFloorKernel.cpp @@ -26,23 +26,63 @@ #include "arm_compute/core/CPP/Validate.h" #include "arm_compute/core/Coordinates.h" #include "arm_compute/core/Helpers.h" -#include "arm_compute/core/IAccessWindow.h" #include "arm_compute/core/ITensor.h" #include "arm_compute/core/NEON/INEKernel.h" #include "arm_compute/core/Validate.h" -#include "src/core/NEON/NEMath.h" -#include +#include "src/core/NEON/kernels/floor/impl/list.h" +#include "src/core/common/Registrars.h" namespace arm_compute { namespace { +struct FloorSelectorData +{ + DataType dt; +}; +using FloorSelectorPtr = std::add_pointer::type; +using FloorUKernelPtr = std::add_pointer::type; + +struct FloorKernel +{ + const char *name; + const FloorSelectorPtr is_selected; + FloorUKernelPtr ukernel; +}; + +static const FloorKernel available_kernels[] = +{ + { + "fp16_neon_floor", + [](const FloorSelectorData & data) { return data.dt == DataType::F16; }, + REGISTER_FP16_NEON(arm_compute::cpu::fp16_neon_floor) + }, + { + "f32_neon_floor", + [](const FloorSelectorData & data) { return data.dt == DataType::F32; }, + REGISTER_FP32_NEON(arm_compute::cpu::fp32_neon_floor) + }, +}; + +const FloorKernel *get_implementation(const FloorSelectorData &data) +{ + for(const auto &uk : available_kernels) + { + if(uk.is_selected(data)) + { + return &uk; + } + } + return nullptr; +} + Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); + + const auto *uk = get_implementation(FloorSelectorData{ input->data_type() }); + ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); // Validate in case of configured output if(output->total_size() > 0) @@ -90,66 +130,19 @@ void NEFloorKernel::run(const Window &window, const ThreadInfo &info) ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - const DataType data_type = _input->info()->data_type(); - - const auto window_start_x = static_cast(window.x().start()); - const auto window_end_x = static_cast(window.x().end()); - const int window_step_x = 16 / _input->info()->element_size(); - Window win{ window }; win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + const auto len = static_cast(window.x().end()) - static_cast(window.x().start()); + const auto *uk = get_implementation(FloorSelectorData{ _input->info()->data_type() }); + Iterator input(_input, win); Iterator output(_output, win); - if(data_type == DataType::F32) + execute_window_loop(win, [&](const Coordinates &) { - execute_window_loop(win, [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast(input.ptr()); - const auto output_ptr = reinterpret_cast(output.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float32x4_t res = vfloorq_f32(vld1q_f32(input_ptr + x)); - vst1q_f32(output_ptr + x, res); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(output_ptr + x) = std::floor(*(input_ptr + x)); - } - }, - input, output); - } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - else if(data_type == DataType::F16) - { - execute_window_loop(win, [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast(input.ptr()); - const auto output_ptr = reinterpret_cast(output.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float16x8_t res = vfloorq_f16(vld1q_f16(input_ptr + x)); - vst1q_f16(output_ptr + x, res); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(output_ptr + x) = std::floor(*(input_ptr + x)); - } - }, - input, output); - } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - else - { - ARM_COMPUTE_ERROR("Invalid data type!"); - } + uk->ukernel(input.ptr(), output.ptr(), len); + }, + input, output); } } // namespace arm_compute -- cgit v1.2.1