From 903f8cca78502a9e3835e6ec42caa1f816274600 Mon Sep 17 00:00:00 2001 From: SiCong Li Date: Thu, 27 Aug 2020 10:17:10 +0100 Subject: COMPMID-3580 Add S32 support to NEArithmeticSubtraction * Fix convert policy validate logics and add missing validate test * Add S32 support to NEArithmeticSubtraction and NEArithmeticSubtractionKernel * Add S32 validation tests Change-Id: I1b6cb15b024613c202fe9f17747a83da43a5ddcf Signed-off-by: SiCong Li Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3908 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio --- .../NEON/kernels/NEArithmeticSubtractionKernel.cpp | 23 +++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) (limited to 'src/core/NEON/kernels') diff --git a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp index 92371936fa..b2700d9cd6 100644 --- a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp +++ b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp @@ -669,9 +669,12 @@ inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &i { ARM_COMPUTE_UNUSED(policy); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16, + DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16, + DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16, + DataType::F32); const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape()); ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); @@ -685,15 +688,16 @@ inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &i && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16) && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8) && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16) + && !(input1.data_type() == DataType::S32 && input2.data_type() == DataType::S32) && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32) && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16), "You called subtract with the wrong image formats"); ARM_COMPUTE_RETURN_ERROR_ON_MSG( - input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED && policy == ConvertPolicy::WRAP - && input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && policy == ConvertPolicy::WRAP - && input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16 && policy == ConvertPolicy::WRAP, - "Convert policy cannot be WRAP if datatype is QASYMM8 or QASYMM8_SIGNED"); + (input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED && policy == ConvertPolicy::WRAP) + || (input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && policy == ConvertPolicy::WRAP) + || (input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16 && policy == ConvertPolicy::WRAP), + "Convert policy cannot be WRAP if datatype is quantized"); // Validate in case of configured output if(output.total_size() > 0) @@ -707,6 +711,7 @@ inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &i && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16) && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16) && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16) + && !(input1.data_type() == DataType::S32 && input2.data_type() == DataType::S32 && output.data_type() == DataType::S32) && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32 && output.data_type() == DataType::F32) && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16 && output.data_type() == DataType::F16), "You called subtract with the wrong image formats"); @@ -776,6 +781,10 @@ void NEArithmeticSubtractionKernel::configure(const ITensorInfo *input1, const I _func = &sub_QSYMM16_QSYMM16_QSYMM16; set_data_type_if_unknown(*output, DataType::QSYMM16); break; + case DataType::S32: + _func = &sub_same; + set_format_if_unknown(*output, Format::S32); + break; #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: _func = &sub_same; -- cgit v1.2.1