aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEElementwiseOperationKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEElementwiseOperationKernel.cpp38
1 files changed, 38 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp b/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
index 6b87ea017b..33457e1fca 100644
--- a/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
@@ -130,6 +130,11 @@ inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const Scalar
res = a / b;
break;
}
+ case ArithmeticOperation::POWER:
+ {
+ res = std::pow(a, b);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
}
@@ -174,12 +179,24 @@ inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, float32x4_t>(
return wrapper::vdiv(a, b);
}
+template <>
+inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, float32x4_t>(const float32x4_t &a, const float32x4_t &b)
+{
+ return wrapper::vpow(a, b);
+}
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, float16x8_t>(const float16x8_t &a, const float16x8_t &b)
{
return wrapper::vdiv(a, b);
}
+
+template <>
+inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, float16x8_t>(const float16x8_t &a, const float16x8_t &b)
+{
+ return wrapper::vpow(a, b);
+}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <ArithmeticOperation op>
@@ -879,6 +896,27 @@ Status NEDivisionOperationKernel::validate(const ITensorInfo *input1, const ITen
return Status{};
}
+/** The power operator */
+void NEPowerOperationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output)
+{
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
+ configure_common(input1, input2, output);
+ _function = configure_arithm_func<ArithmeticOperation::POWER>(input1, input2, output);
+}
+
+Status NEPowerOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32);
+ return NEArithmeticOperationKernel::validate_arguments(input1, input2, output);
+}
+
+Status NEPowerOperationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
+ return Status{};
+}
+
/** Comparison operators (equal, not equal, less than, greater than, less than or equal, greater than or equal) */
void NEComparisonOperationKernel::configure(ComparisonOperation op, const ITensor *input1, const ITensor *input2, ITensor *output)