diff options
Diffstat (limited to 'src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp | 53 |
1 file changed, 42 insertions, 11 deletions
diff --git a/src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp b/src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp index 34696d872a..d62b165727 100644 --- a/src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp +++ b/src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp @@ -55,12 +55,15 @@ inline ScalarType elementwise_op_scalar(const ScalarType &a) return 1 / sqrt(a); case ElementWiseUnary::EXP: return std::exp(a); + case ElementWiseUnary::NEG: + return -a; default: ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); } } -template <ElementWiseUnary op, typename VectorType> +/* Elementwise operations that are supported for float */ +template <ElementWiseUnary op, bool is_float, typename VectorType, typename std::enable_if<is_float, int>::type = 0> inline VectorType elementwise_op(const VectorType &a) { switch(op) @@ -69,12 +72,27 @@ inline VectorType elementwise_op(const VectorType &a) return wrapper::vinvsqrt(a); case ElementWiseUnary::EXP: return wrapper::vexpq(a); + case ElementWiseUnary::NEG: + return wrapper::vneg(a); default: ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); } } -template <ElementWiseUnary op, typename ScalarType> +/* Elementwise operations that are supported for non floats */ +template <ElementWiseUnary op, bool is_float, typename VectorType, typename std::enable_if<!is_float, int>::type = 0> +inline VectorType elementwise_op(const VectorType &a) +{ + switch(op) + { + case ElementWiseUnary::NEG: + return wrapper::vneg(a); + default: + ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); + } +} + +template <ElementWiseUnary op, typename ScalarType, bool is_float> void elementwise_op(const ITensor *in, ITensor *out, const Window &window) { const int window_step_x = 16 / sizeof(ScalarType); @@ -95,7 +113,7 @@ void elementwise_op(const ITensor *in, ITensor *out, const Window &window) int x = window_start_x; for(; x <= window_end_x - window_step_x; x += window_step_x) { - wrapper::vstore(output_ptr + x, elementwise_op<op>(wrapper::vloadq(input_ptr + x))); + wrapper::vstore(output_ptr + x, 
elementwise_op<op, is_float>(wrapper::vloadq(input_ptr + x))); } for(; x < window_end_x; ++x) { @@ -115,10 +133,11 @@ configure_func(const ITensor *input, ITensor *output) static std::map<std::string, NEElementwiseUnaryKernel::ElementwiseUnaryFunction *> map_function = { - { "op_F32_F32", &elementwise_op<op, float> } + { "op_F32_F32", &elementwise_op<op, float, true> }, + { "op_S32_S32", &elementwise_op<op, int32_t, false> }, }; #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - map_function["op_F16_F16"] = &elementwise_op<op, float16_t>; + map_function["op_F16_F16"] = &elementwise_op<op, float16_t, true>; #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ auto it = map_function.find(function_to_call); @@ -142,7 +161,7 @@ NEElementwiseUnaryKernel::NEElementwiseUnaryKernel() void NEElementwiseUnaryKernel::configure(ElementWiseUnary op, const ITensor *input, ITensor *output) { - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input->info(), *output->info())); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(op, *input->info(), *output->info())); ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); // Configure kernel window @@ -168,16 +187,29 @@ void NEElementwiseUnaryKernel::configure(ElementWiseUnary op, const ITensor *inp case ElementWiseUnary::EXP: _function = configure_func<ElementWiseUnary::EXP>(input, output); break; + case ElementWiseUnary::NEG: + _function = configure_func<ElementWiseUnary::NEG>(input, output); + break; default: ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); } } -Status NEElementwiseUnaryKernel::validate_arguments(const ITensorInfo &input, const ITensorInfo &output) +Status NEElementwiseUnaryKernel::validate_arguments(ElementWiseUnary op, const ITensorInfo &input, const ITensorInfo &output) { ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::F16, DataType::F32); - + switch(op) + { + case ElementWiseUnary::EXP: + case ElementWiseUnary::RSQRT: + 
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::F16, DataType::F32); + break; + case ElementWiseUnary::NEG: + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::F16, DataType::F32, DataType::S32); + break; + default: + ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); + } // Validate in case of configured output if(output.total_size() > 0) { @@ -189,9 +221,8 @@ Status NEElementwiseUnaryKernel::validate_arguments(const ITensorInfo &input, co Status NEElementwiseUnaryKernel::validate(ElementWiseUnary op, const ITensorInfo *input, const ITensorInfo *output) { - ARM_COMPUTE_UNUSED(op); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input, *output)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(op, *input, *output)); return Status{}; } |