diff options
author | SiCong Li <sicong.li@arm.com> | 2020-08-27 10:17:10 +0100 |
---|---|---|
committer | SiCong Li <sicong.li@arm.com> | 2020-09-07 10:57:52 +0000 |
commit | 903f8cca78502a9e3835e6ec42caa1f816274600 (patch) | |
tree | e8e104990b0b718550797bfe7c7c67c2a722e849 /src/core | |
parent | 2d2213920ba5ab95052a557dd20594a6ccb7d562 (diff) | |
download | ComputeLibrary-903f8cca78502a9e3835e6ec42caa1f816274600.tar.gz |
COMPMID-3580 Add S32 support to NEArithmeticSubtraction
* Fix convert policy validate logics and add missing validate test
* Add S32 support to NEArithmeticSubtraction and NEArithmeticSubtractionKernel
* Add S32 validation tests
Change-Id: I1b6cb15b024613c202fe9f17747a83da43a5ddcf
Signed-off-by: SiCong Li <sicong.li@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3908
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Diffstat (limited to 'src/core')
-rw-r--r-- | src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp | 23 |
1 files changed, 16 insertions, 7 deletions
diff --git a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp index 92371936fa..b2700d9cd6 100644 --- a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp +++ b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp @@ -669,9 +669,12 @@ inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &i { ARM_COMPUTE_UNUSED(policy); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16, + DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16, + DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16, + DataType::F32); const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape()); ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); @@ -685,15 +688,16 @@ inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &i && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16) && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8) && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16) + && !(input1.data_type() == DataType::S32 && input2.data_type() == DataType::S32) && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32) && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16), "You called subtract with the wrong image formats"); ARM_COMPUTE_RETURN_ERROR_ON_MSG( - input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED && policy == ConvertPolicy::WRAP - && input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && policy == ConvertPolicy::WRAP - && input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16 && policy == ConvertPolicy::WRAP, - "Convert policy cannot be WRAP if datatype is QASYMM8 or QASYMM8_SIGNED"); + (input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED && policy == ConvertPolicy::WRAP) + || (input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && policy == ConvertPolicy::WRAP) + || (input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16 && policy == ConvertPolicy::WRAP), + "Convert policy cannot be WRAP if datatype is quantized"); // Validate in case of configured output if(output.total_size() > 0) @@ -707,6 +711,7 @@ inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &i && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16) && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16) && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16) + && !(input1.data_type() == DataType::S32 && input2.data_type() == DataType::S32 && output.data_type() == DataType::S32) && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32 && output.data_type() == DataType::F32) && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16 && output.data_type() == DataType::F16), "You called subtract with the wrong image formats"); @@ -776,6 +781,10 @@ void NEArithmeticSubtractionKernel::configure(const ITensorInfo *input1, const I _func = &sub_QSYMM16_QSYMM16_QSYMM16; set_data_type_if_unknown(*output, DataType::QSYMM16); break; + case DataType::S32: + _func = &sub_same<int32_t>; + set_format_if_unknown(*output, Format::S32); + break; #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: _func = &sub_same<float16_t>; |