aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp23
1 files changed, 16 insertions, 7 deletions
diff --git a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
index 92371936fa..b2700d9cd6 100644
--- a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
@@ -669,9 +669,12 @@ inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &i
{
ARM_COMPUTE_UNUSED(policy);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16,
+ DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16,
+ DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16,
+ DataType::F32);
const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
@@ -685,15 +688,16 @@ inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &i
&& !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16)
&& !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8)
&& !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16)
+ && !(input1.data_type() == DataType::S32 && input2.data_type() == DataType::S32)
&& !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32)
&& !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16),
"You called subtract with the wrong image formats");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED && policy == ConvertPolicy::WRAP
- && input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && policy == ConvertPolicy::WRAP
- && input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16 && policy == ConvertPolicy::WRAP,
- "Convert policy cannot be WRAP if datatype is QASYMM8 or QASYMM8_SIGNED");
+ (input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED && policy == ConvertPolicy::WRAP)
+ || (input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && policy == ConvertPolicy::WRAP)
+ || (input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16 && policy == ConvertPolicy::WRAP),
+ "Convert policy cannot be WRAP if datatype is quantized");
// Validate in case of configured output
if(output.total_size() > 0)
@@ -707,6 +711,7 @@ inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &i
&& !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
&& !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
&& !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
+ && !(input1.data_type() == DataType::S32 && input2.data_type() == DataType::S32 && output.data_type() == DataType::S32)
&& !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32 && output.data_type() == DataType::F32)
&& !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16 && output.data_type() == DataType::F16),
"You called subtract with the wrong image formats");
@@ -776,6 +781,10 @@ void NEArithmeticSubtractionKernel::configure(const ITensorInfo *input1, const I
_func = &sub_QSYMM16_QSYMM16_QSYMM16;
set_data_type_if_unknown(*output, DataType::QSYMM16);
break;
+ case DataType::S32:
+ _func = &sub_same<int32_t>;
+ set_format_if_unknown(*output, Format::S32);
+ break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
_func = &sub_same<float16_t>;