diff options
Diffstat (limited to 'src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp | 148 |
1 files changed, 60 insertions, 88 deletions
diff --git a/src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp b/src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp index d899643fdc..ed1cb6fca4 100644 --- a/src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp +++ b/src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -28,7 +28,10 @@ #include "arm_compute/core/ITensor.h" #include "arm_compute/core/Validate.h" #include "src/core/CPP/Validate.h" -#include "src/core/NEON/wrapper/wrapper.h" +#include "src/core/NEON/kernels/elementwise/impl/elementwise_unary_list.h" +#include "src/core/SVE/kernels/elementwise/impl/elementwise_unary_list.h" +#include "src/core/common/Registrars.h" +#include "src/core/common/StdTypes.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" #include "support/ToolchainSupport.h" @@ -37,85 +40,65 @@ namespace arm_compute { namespace { -template <typename ScalarType> -inline ScalarType elementwise_op_scalar_imp(ElementWiseUnary op, const ScalarType &a) +using ElementwiseUnarySelector = std::add_pointer<bool(DataType)>::type; + +struct ElementwiseUnaryKernel { - switch(op) - { - case ElementWiseUnary::RSQRT: - return 1 / sqrt(a); - case ElementWiseUnary::EXP: - return std::exp(a); - case ElementWiseUnary::NEG: - return -a; - case ElementWiseUnary::LOG: - return std::log(a); - case ElementWiseUnary::ABS: - return std::abs(a); - case ElementWiseUnary::ROUND: - return support::cpp11::nearbyint(a); - case ElementWiseUnary::SIN: - return std::sin(a); - default: - ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); - } -} + const char *name; + const ElementwiseUnarySelector is_selected; + NEElementwiseUnaryKernel::ElementwiseUnaryUkernelPtr ukernel; +}; -template <typename ScalarType, typename VectorType> -inline VectorType elementwise_op_imp(ElementWiseUnary op, const VectorType &a) +static const ElementwiseUnaryKernel available_kernels[] = { - switch(op) +#if defined(__ARM_FEATURE_SVE) { - case ElementWiseUnary::RSQRT: - return wrapper::vinvsqrt(a); - case ElementWiseUnary::EXP: - return wrapper::vexpq(a); - case ElementWiseUnary::NEG: - return wrapper::vneg(a); - case ElementWiseUnary::LOG: - return wrapper::vlog(a); - case ElementWiseUnary::ABS: - return wrapper::vabs(a); - case ElementWiseUnary::ROUND: - return wrapper::vround(a); - case ElementWiseUnary::SIN: - return wrapper::vsin(a); - default: - ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); - } -} -} // namespace + "fp32_sve_elementwise_unary", + [](DataType dt) { return dt == DataType::F32; }, + REGISTER_FP32_SVE(arm_compute::cpu::elementwise_sve_op<f32>), + }, + { + "fp16_sve_elementwise_unary", + [](DataType dt) { return dt == DataType::F16; }, + REGISTER_FP16_SVE(arm_compute::cpu::elementwise_sve_op<f16>), + }, + { + "s32_sve_elementwise_unary", + [](DataType dt) { return dt == DataType::S32; }, + REGISTER_INTEGER_SVE(arm_compute::cpu::elementwise_sve_op<s32>), + }, +#endif // defined(__ARM_FEATURE_SVE) + { + "fp32_neon_elementwise_unary", + [](DataType dt) { return dt == DataType::F32; }, + REGISTER_FP32_NEON(arm_compute::cpu::elementwise_op<f32>), + }, +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + { + "fp16_neon_elementwise_unary", + [](DataType dt) { return dt == DataType::F16; }, + REGISTER_FP32_NEON(arm_compute::cpu::elementwise_op<f16>), + }, +#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + { + "s32_neon_elementwise_unary", + [](DataType dt) { return dt == DataType::S32; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::elementwise_op<s32>), + }, +}; -template <typename ScalarType> -void NEElementwiseUnaryKernel::elementwise_op(const Window &window) +const ElementwiseUnaryKernel *get_implementation(DataType dt) { - const int window_step_x = 16 / sizeof(ScalarType); - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - Window win = window; - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input(_input, win); - Iterator output(_output, win); - - execute_window_loop(win, [&](const Coordinates &) + for(const auto &uk : available_kernels) { - auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr()); - const auto input_ptr = reinterpret_cast<const ScalarType *>(input.ptr()); - - int x = window_start_x; - for(; x <= window_end_x - window_step_x; x += window_step_x) + if(uk.is_selected(dt)) { - wrapper::vstore(output_ptr + x, elementwise_op_imp<ScalarType>(_op, wrapper::vloadq(input_ptr + x))); + return &uk; } - for(; x < window_end_x; ++x) - { - *(output_ptr + x) = elementwise_op_scalar_imp(_op, *(input_ptr + x)); - } - }, - input, output); + } + return nullptr; } +} // namespace NEElementwiseUnaryKernel::NEElementwiseUnaryKernel() : _func(nullptr), _input(nullptr), _output(nullptr), _op() @@ -143,28 +126,17 @@ void NEElementwiseUnaryKernel::configure(ElementWiseUnary op, const ITensor *inp INEKernel::configure(win); - switch(input->info()->data_type()) - { - case DataType::F32: - _func = &NEElementwiseUnaryKernel::elementwise_op<float>; - break; -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - _func = &NEElementwiseUnaryKernel::elementwise_op<float16_t>; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - break; - case DataType::S32: - _func = &NEElementwiseUnaryKernel::elementwise_op<int32_t>; - break; - default: - ARM_COMPUTE_ERROR("DataType not supported"); - } + _func = get_implementation(input->info()->data_type())->ukernel; } Status NEElementwiseUnaryKernel::validate(ElementWiseUnary op, const ITensorInfo *input, const ITensorInfo *output) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input); + + const auto *uk = get_implementation(input->data_type()); + ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); + switch(op) { case ElementWiseUnary::EXP: @@ -196,6 +168,6 @@ void NEElementwiseUnaryKernel::run(const Window &window, const ThreadInfo &info) ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); ARM_COMPUTE_ERROR_ON(_func == nullptr); - (this->*_func)(window); + (*_func)(_input, _output, window, _op); } } // namespace arm_compute |