aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEElementwiseOperationKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEElementwiseOperationKernel.cpp16
1 files changed, 15 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp b/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
index db4f5923bc..da53a523e6 100644
--- a/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
@@ -142,6 +142,14 @@ inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const Scalar
case ArithmeticOperation::DIV:
{
res = a / b;
+ if(std::is_integral<ScalarType>::value)
+ {
+ res = (b == 0) ? 0 : res;
+ if(static_cast<int32_t>(a) % static_cast<int32_t>(b) != 0 && ((a < 0) != (b < 0)))
+ {
+ --res;
+ }
+ }
break;
}
case ArithmeticOperation::POWER:
@@ -208,6 +216,12 @@ inline typename VectorType::type elementwise_arithm_op(const typename VectorType
}
template <>
+inline int32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<int32_t, 4>>(const int32x4_t &a, const int32x4_t &b)
+{
+ return vcvtq_s32_f32(vfloorq_f32(wrapper::vdiv(vcvtq_f32_s32(a), vcvtq_f32_s32(b))));
+}
+
+template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
return wrapper::vdiv(a, b);
@@ -1259,7 +1273,7 @@ void NEDivisionOperationKernel::configure(const ITensorInfo *input1, const ITens
Status NEDivisionOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::S32, DataType::F16, DataType::F32);
return NEArithmeticOperationKernel::validate_arguments(input1, input2, output);
}