diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/core/NEON/kernels/NEElementwiseOperationKernel.cpp | 42 | ||||
-rw-r--r-- | src/runtime/NEON/functions/NEElementwiseOperators.cpp | 12 |
2 files changed, 53 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp b/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp index 88fd730554..99a3b5ac66 100644 --- a/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp +++ b/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp @@ -123,6 +123,11 @@ inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const Scalar res = (a - b) * (a - b); break; } + case ArithmeticOperation::DIV: + { + res = a / b; + break; + } default: ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); } @@ -154,7 +159,6 @@ inline VectorType elementwise_arithm_op(const VectorType &a, const VectorType &b res = wrapper::vmul(tmp, tmp); break; } - default: ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); } @@ -162,6 +166,20 @@ inline VectorType elementwise_arithm_op(const VectorType &a, const VectorType &b return res; } +template <> +inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, float32x4_t>(const float32x4_t &a, const float32x4_t &b) +{ + return wrapper::vdiv(a, b); +} + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, float16x8_t>(const float16x8_t &a, const float16x8_t &b) +{ + return wrapper::vdiv(a, b); +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + template <ArithmeticOperation op> inline float32x4x4_t elementwise_arithm_op(const float32x4x4_t &a, const float32x4x4_t &b) { @@ -833,6 +851,28 @@ Status NEArithmeticOperationKernel::validate(ArithmeticOperation op, const ITens return Status{}; } +/** The division operator */ + +void NEDivisionOperationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output) +{ + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info())); + configure_common(input1, input2, output); + _function = configure_arithm_func<ArithmeticOperation::DIV>(input1, input2, output); +} + +Status NEDivisionOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output) +{ + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32); + return NEArithmeticOperationKernel::validate_arguments(input1, input2, output); +} + +Status NEDivisionOperationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output)); + return Status{}; +} + /** Comparison operators (equal, not equal, less than, greater than, less than or equal, greater than or equal) */ void NEComparisonOperationKernel::configure(ComparisonOperation op, const ITensor *input1, const ITensor *input2, ITensor *output) diff --git a/src/runtime/NEON/functions/NEElementwiseOperators.cpp b/src/runtime/NEON/functions/NEElementwiseOperators.cpp index 711e99ea77..4c570685bc 100644 --- a/src/runtime/NEON/functions/NEElementwiseOperators.cpp +++ b/src/runtime/NEON/functions/NEElementwiseOperators.cpp @@ -67,6 +67,18 @@ Status NEElementwiseSquaredDiff::validate(const ITensorInfo *input1, const ITens return NEArithmeticOperationKernel::validate(ArithmeticOperation::SQUARED_DIFF, input1, input2, output); } +void NEElementwiseDivision::configure(ITensor *input1, ITensor *input2, ITensor *output) +{ + auto k = arm_compute::support::cpp14::make_unique<NEDivisionOperationKernel>(); + k->configure(input1, input2, output); + _kernel = std::move(k); +} + +Status NEElementwiseDivision::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output) +{ + return NEDivisionOperationKernel::validate(input1, input2, output); +} + template <ComparisonOperation COP> void NEElementwiseComparisonStatic<COP>::configure(ITensor *input1, ITensor *input2, ITensor *output) { |