From f6e475c9a092bc6e0fb53f484fbf2832183a9c44 Mon Sep 17 00:00:00 2001 From: Usama Arif Date: Fri, 10 May 2019 12:06:28 +0100 Subject: COMPMID-2268: Implement NEG for NEON. Change-Id: I90c023dbea8ea12e9af677294ba576b2bfcc02a4 Signed-off-by: Usama Arif Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/184216 Tested-by: bsgcomp Comments-Addressed: bsgcomp Reviewed-by: Pablo Tello Reviewed-on: https://review.mlplatform.org/c/1099 Tested-by: Arm Jenkins Reviewed-by: Pablo Marquez Comments-Addressed: Arm Jenkins --- src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp | 53 +++++++++++++++++----- 1 file changed, 42 insertions(+), 11 deletions(-) (limited to 'src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp') diff --git a/src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp b/src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp index 34696d872a..d62b165727 100644 --- a/src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp +++ b/src/core/NEON/kernels/NEElementwiseUnaryKernel.cpp @@ -55,12 +55,15 @@ inline ScalarType elementwise_op_scalar(const ScalarType &a) return 1 / sqrt(a); case ElementWiseUnary::EXP: return std::exp(a); + case ElementWiseUnary::NEG: + return -a; default: ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); } } -template +/* Elementwise operations that are supported for float */ +template ::type = 0> inline VectorType elementwise_op(const VectorType &a) { switch(op) @@ -69,12 +72,27 @@ inline VectorType elementwise_op(const VectorType &a) return wrapper::vinvsqrt(a); case ElementWiseUnary::EXP: return wrapper::vexpq(a); + case ElementWiseUnary::NEG: + return wrapper::vneg(a); default: ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); } } -template +/* Elementwise operations that are supported for non floats */ +template ::type = 0> +inline VectorType elementwise_op(const VectorType &a) +{ + switch(op) + { + case ElementWiseUnary::NEG: + return wrapper::vneg(a); + default: + ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); + } +} + +template void elementwise_op(const ITensor *in, ITensor *out, const Window &window) { const int window_step_x = 16 / sizeof(ScalarType); @@ -95,7 +113,7 @@ void elementwise_op(const ITensor *in, ITensor *out, const Window &window) int x = window_start_x; for(; x <= window_end_x - window_step_x; x += window_step_x) { - wrapper::vstore(output_ptr + x, elementwise_op(wrapper::vloadq(input_ptr + x))); + wrapper::vstore(output_ptr + x, elementwise_op(wrapper::vloadq(input_ptr + x))); } for(; x < window_end_x; ++x) { @@ -115,10 +133,11 @@ configure_func(const ITensor *input, ITensor *output) static std::map map_function = { - { "op_F32_F32", &elementwise_op } + { "op_F32_F32", &elementwise_op }, + { "op_S32_S32", &elementwise_op }, }; #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - map_function["op_F16_F16"] = &elementwise_op; + map_function["op_F16_F16"] = &elementwise_op; #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ auto it = map_function.find(function_to_call); @@ -142,7 +161,7 @@ NEElementwiseUnaryKernel::NEElementwiseUnaryKernel() void NEElementwiseUnaryKernel::configure(ElementWiseUnary op, const ITensor *input, ITensor *output) { - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input->info(), *output->info())); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(op, *input->info(), *output->info())); ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); // Configure kernel window @@ -168,16 +187,29 @@ void NEElementwiseUnaryKernel::configure(ElementWiseUnary op, const ITensor *inp case ElementWiseUnary::EXP: _function = configure_func(input, output); break; + case ElementWiseUnary::NEG: + _function = configure_func(input, output); + break; default: ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); } } -Status NEElementwiseUnaryKernel::validate_arguments(const ITensorInfo &input, const ITensorInfo &output) +Status NEElementwiseUnaryKernel::validate_arguments(ElementWiseUnary op, const ITensorInfo &input, const ITensorInfo &output) { ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::F16, DataType::F32); - + switch(op) + { + case ElementWiseUnary::EXP: + case ElementWiseUnary::RSQRT: + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::F16, DataType::F32); + break; + case ElementWiseUnary::NEG: + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::F16, DataType::F32, DataType::S32); + break; + default: + ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); + } // Validate in case of configured output if(output.total_size() > 0) { @@ -189,9 +221,8 @@ Status NEElementwiseUnaryKernel::validate_arguments(const ITensorInfo &input, co Status NEElementwiseUnaryKernel::validate(ElementWiseUnary op, const ITensorInfo *input, const ITensorInfo *output) { - ARM_COMPUTE_UNUSED(op); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input, *output)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(op, *input, *output)); return Status{}; } -- cgit v1.2.1