author    SiCong Li <sicong.li@arm.com>  2020-08-27 10:17:10 +0100
committer SiCong Li <sicong.li@arm.com>  2020-09-07 10:57:52 +0000
commit    903f8cca78502a9e3835e6ec42caa1f816274600 (patch)
tree      e8e104990b0b718550797bfe7c7c67c2a722e849
parent    2d2213920ba5ab95052a557dd20594a6ccb7d562 (diff)
download  ComputeLibrary-903f8cca78502a9e3835e6ec42caa1f816274600.tar.gz
COMPMID-3580 Add S32 support to NEArithmeticSubtraction
* Fix convert policy validation logic and add a missing validate test
* Add S32 support to NEArithmeticSubtraction and NEArithmeticSubtractionKernel
* Add S32 validation tests

Change-Id: I1b6cb15b024613c202fe9f17747a83da43a5ddcf
Signed-off-by: SiCong Li <sicong.li@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3908
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
-rw-r--r--  arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h | 22
-rw-r--r--  arm_compute/core/NEON/wrapper/scalar/sub.h                    |  7
-rw-r--r--  arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h  | 30
-rw-r--r--  docs/00_introduction.dox                                      |  3
-rw-r--r--  src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp       | 23
-rw-r--r--  tests/validation/NEON/ArithmeticSubtraction.cpp               | 24
6 files changed, 77 insertions(+), 32 deletions(-)
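
With this patch the public runtime function accepts S32 end to end. A minimal usage sketch under the usual Arm Compute Library setup (shapes and the fill step are illustrative, not taken from this patch):

    #include "arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    Tensor input1, input2, output;
    input1.allocator()->init(TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32));
    input2.allocator()->init(TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32));
    output.allocator()->init(TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32));

    NEArithmeticSubtraction sub;
    // SATURATE clamps each element to [INT32_MIN, INT32_MAX]; WRAP is also legal for S32.
    sub.configure(&input1, &input2, &output, ConvertPolicy::SATURATE);

    input1.allocator()->allocate();
    input2.allocator()->allocate();
    output.allocator()->allocate();
    // ... fill input1 and input2 ...
    sub.run();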
diff --git a/arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h b/arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h
index e3a41a2b1c..7d00d1f7d0 100644
--- a/arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h
+++ b/arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h
@@ -63,12 +63,13 @@ public:
* - (S16,U8) -> S16
* - (U8,S16) -> S16
* - (S16,S16) -> S16
+ * - (S32,S32) -> S32
* - (F16,F16) -> F16
* - (F32,F32) -> F32
*
- * @param[in] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32
- * @param[in] input2 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32
- * @param[out] output The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32.
+ * @param[in] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32
+ * @param[in] input2 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32
+ * @param[out] output The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32.
* @param[in] policy Overflow policy. Convert policy cannot be WRAP if datatype is quantized.
*/
void configure(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output, ConvertPolicy policy);
@@ -83,14 +84,13 @@ public:
* - (S16,U8) -> S16
* - (U8,S16) -> S16
* - (S16,S16) -> S16
+ * - (S32,S32) -> S32
* - (F16,F16) -> F16
* - (F32,F32) -> F32
*
- * @note Convert policy cannot be WRAP if datatype is QASYMM8
- *
- * @param[in] input1 An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32
- * @param[in] input2 An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32
- * @param[in] output The output tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32.
+ * @param[in] input1 An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32
+ * @param[in] input2 An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32
+ * @param[in] output The output tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32.
* @param[in] policy Policy to use to handle overflow. Convert policy cannot be WRAP if datatype is quantized.
*
* @return a status
@@ -103,9 +103,9 @@ public:
private:
/** Common signature for all the specialised sub functions
*
- * @param[in] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32
- * @param[in] input2 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32
- * @param[out] output The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32.
+ * @param[in] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32
+ * @param[in] input2 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32
+ * @param[out] output The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32.
* @param[in] window Region on which to execute the kernel.
* @param[in] is_sat Flag to indicate if the policy is SATURATE.
*/
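
The static validate() documented above mirrors configure() and can pre-check an S32 pairing without allocating any tensor memory; a short sketch (shapes are illustrative):

    TensorInfo in1(TensorShape(27U, 13U, 2U), 1, DataType::S32);
    TensorInfo in2(TensorShape(27U, 13U, 2U), 1, DataType::S32);
    TensorInfo out(TensorShape(27U, 13U, 2U), 1, DataType::S32);
    // WRAP is acceptable here: the policy is only restricted for quantized types.
    Status s = NEArithmeticSubtractionKernel::validate(&in1, &in2, &out, ConvertPolicy::WRAP);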
diff --git a/arm_compute/core/NEON/wrapper/scalar/sub.h b/arm_compute/core/NEON/wrapper/scalar/sub.h
index 9abda26224..1fe51d75fc 100644
--- a/arm_compute/core/NEON/wrapper/scalar/sub.h
+++ b/arm_compute/core/NEON/wrapper/scalar/sub.h
@@ -44,6 +44,13 @@ inline int16_t sub_sat(const int16_t &a, const int16_t &b)
return vget_lane_s16(vqsub_s16(va, vb), 0);
}
+inline int32_t sub_sat(const int32_t &a, const int32_t &b)
+{
+ const int32x2_t va = { a, 0 };
+ const int32x2_t vb = { b, 0 };
+ return vget_lane_s32(vqsub_s32(va, vb), 0);
+}
+
inline float sub_sat(const float &a, const float &b)
{
// No notion of saturation exists in floating point
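
The new overload lifts the scalars into int32x2_t lanes so that vqsub_s32 gives the scalar path exactly the same saturation as the vector path. Its semantics, modelled without intrinsics (a standalone reference sketch, not part of the patch):

    #include <cstdint>
    #include <limits>

    // Reference model of vqsub_s32-style saturation: compute in 64 bits,
    // then clamp into the int32 range instead of wrapping.
    int32_t sub_sat_ref(int32_t a, int32_t b)
    {
        const int64_t r  = static_cast<int64_t>(a) - static_cast<int64_t>(b);
        const int64_t lo = std::numeric_limits<int32_t>::min();
        const int64_t hi = std::numeric_limits<int32_t>::max();
        return static_cast<int32_t>(r < lo ? lo : (r > hi ? hi : r));
    }
    // e.g. sub_sat_ref(INT32_MIN, 1) == INT32_MIN rather than wrapping around.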
diff --git a/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h b/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h
index a38335c59b..5d2475b3a4 100644
--- a/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h
+++ b/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h
@@ -36,7 +36,7 @@ namespace experimental
{
/** Basic function to run @ref NEArithmeticSubtractionKernel
*
- * @note The tensor data type for the inputs must be U8/QASYMM8/S16/F16/F32.
+ * @note The tensor data type for the inputs must be U8/QASYMM8/S16/S32/F16/F32.
* @note The function performs an arithmetic subtraction between two tensors.
*
* This function calls the following kernels:
@@ -56,12 +56,13 @@ public:
* - (S16,U8) -> S16
* - (U8,S16) -> S16
* - (S16,S16) -> S16
+ * - (S32,S32) -> S32
* - (F16,F16) -> F16
* - (F32,F32) -> F32
*
- * @param[in] input1 First tensor input info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32
- * @param[in] input2 Second tensor input info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32
- * @param[out] output Output tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32
+ * @param[in] input1 First tensor input info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32
+ * @param[in] input2 Second tensor input info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32
+ * @param[out] output Output tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32
* @param[in] policy Policy to use to handle overflow. Convert policy cannot be WRAP if datatype is quantized.
* @param[in] act_info (Optional) Activation layer information in case of a fused activation. Currently not supported.
*/
@@ -77,12 +78,13 @@ public:
* - (S16,U8) -> S16
* - (U8,S16) -> S16
* - (S16,S16) -> S16
+ * - (S32,S32) -> S32
* - (F16,F16) -> F16
* - (F32,F32) -> F32
*
- * @param[in] input1 First tensor input. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/F16/F32
- * @param[in] input2 Second tensor input. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/F16/F32
- * @param[in] output Output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/F16/F32
+ * @param[in] input1 First tensor input. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/S32/F16/F32
+ * @param[in] input2 Second tensor input. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/S32/F16/F32
+ * @param[in] output Output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/S32/F16/F32
* @param[in] policy Policy to use to handle overflow. Convert policy cannot be WRAP if datatype is quantized.
* @param[in] act_info (Optional) Activation layer information in case of a fused activation. Currently not supported.
*
@@ -94,7 +96,7 @@ public:
/** Basic function to run @ref NEArithmeticSubtractionKernel
*
- * @note The tensor data type for the inputs must be U8/QASYMM8/S16/F16/F32.
+ * @note The tensor data type for the inputs must be U8/QASYMM8/S16/S32/F16/F32.
* @note The function performs an arithmetic subtraction between two tensors.
*
* This function calls the following kernels:
@@ -117,18 +119,18 @@ public:
NEArithmeticSubtraction &operator=(NEArithmeticSubtraction &&);
/** Initialise the kernel's inputs, output and conversion policy.
*
- * @param[in] input1 First tensor input. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32
- * @param[in] input2 Second tensor input. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32
- * @param[out] output Output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32
+ * @param[in] input1 First tensor input. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32
+ * @param[in] input2 Second tensor input. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32
+ * @param[out] output Output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/S32/F16/F32
* @param[in] policy Policy to use to handle overflow. Convert policy cannot be WRAP if datatype is quantized.
* @param[in] act_info (Optional) Activation layer information in case of a fused activation. Currently not supported.
*/
void configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy, const ActivationLayerInfo &act_info = ActivationLayerInfo());
/** Static function to check if given info will lead to a valid configuration of @ref NEArithmeticSubtraction
*
- * @param[in] input1 First tensor input. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/F16/F32
- * @param[in] input2 Second tensor input. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/F16/F32
- * @param[in] output Output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/F16/F32
+ * @param[in] input1 First tensor input. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/S32/F16/F32
+ * @param[in] input2 Second tensor input. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/S32/F16/F32
+ * @param[in] output Output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/S32/F16/F32
* @param[in] policy Policy to use to handle overflow. Convert policy cannot be WRAP if datatype is quantized.
* @param[in] act_info (Optional) Activation layer information in case of a fused activation. Currently not supported.
*
diff --git a/docs/00_introduction.dox b/docs/00_introduction.dox
index 538fb3632b..9db7e57cbe 100644
--- a/docs/00_introduction.dox
+++ b/docs/00_introduction.dox
@@ -238,6 +238,9 @@ If there is more than one release in a month then an extra sequential number is
@subsection S2_2_changelog Changelog
v20.11 Public major release
+ - Added new data type S32 support for:
+ - @ref NEArithmeticSubtraction
+ - @ref NEArithmeticSubtractionKernel
- Interface change
- Properly support softmax axis to have the same meaning as other major frameworks. That is, axis now defines the dimension
on which Softmax/Logsoftmax is performed. E.g. for input of shape 4x5x6 and axis=1, softmax will be applied to 4x6=24 vectors of size 5.
diff --git a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
index 92371936fa..b2700d9cd6 100644
--- a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
@@ -669,9 +669,12 @@ inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &i
{
ARM_COMPUTE_UNUSED(policy);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16,
+ DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16,
+ DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16,
+ DataType::F32);
const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
@@ -685,15 +688,16 @@ inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &i
&& !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16)
&& !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8)
&& !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16)
+ && !(input1.data_type() == DataType::S32 && input2.data_type() == DataType::S32)
&& !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32)
&& !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16),
"You called subtract with the wrong image formats");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED && policy == ConvertPolicy::WRAP
- && input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && policy == ConvertPolicy::WRAP
- && input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16 && policy == ConvertPolicy::WRAP,
- "Convert policy cannot be WRAP if datatype is QASYMM8 or QASYMM8_SIGNED");
+ (input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED && policy == ConvertPolicy::WRAP)
+ || (input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && policy == ConvertPolicy::WRAP)
+ || (input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16 && policy == ConvertPolicy::WRAP),
+ "Convert policy cannot be WRAP if datatype is quantized");
// Validate in case of configured output
if(output.total_size() > 0)
@@ -707,6 +711,7 @@ inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &i
&& !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
&& !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
&& !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
+ && !(input1.data_type() == DataType::S32 && input2.data_type() == DataType::S32 && output.data_type() == DataType::S32)
&& !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32 && output.data_type() == DataType::F32)
&& !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16 && output.data_type() == DataType::F16),
"You called subtract with the wrong image formats");
@@ -776,6 +781,10 @@ void NEArithmeticSubtractionKernel::configure(const ITensorInfo *input1, const I
_func = &sub_QSYMM16_QSYMM16_QSYMM16;
set_data_type_if_unknown(*output, DataType::QSYMM16);
break;
+ case DataType::S32:
+ _func = &sub_same<int32_t>;
+ set_format_if_unknown(*output, Format::S32);
+ break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
_func = &sub_same<float16_t>;
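
Besides the S32 dispatch, the hunk above repairs the quantized-WRAP check: the old '&&' chain required a single tensor to hold several data types at once, so the condition was always false and WRAP was silently accepted for quantized inputs. The corrected predicate, reduced to a self-contained sketch with stand-in enums (illustrative only):

    enum class DataType { QASYMM8, QASYMM8_SIGNED, QSYMM16, S32 };
    enum class ConvertPolicy { WRAP, SATURATE };

    // True exactly when a quantized pairing is combined with WRAP, which must be rejected.
    bool wrap_on_quantized(DataType t1, DataType t2, ConvertPolicy p)
    {
        return (t1 == DataType::QASYMM8_SIGNED && t2 == DataType::QASYMM8_SIGNED && p == ConvertPolicy::WRAP)
            || (t1 == DataType::QASYMM8        && t2 == DataType::QASYMM8        && p == ConvertPolicy::WRAP)
            || (t1 == DataType::QSYMM16        && t2 == DataType::QSYMM16        && p == ConvertPolicy::WRAP);
    }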
diff --git a/tests/validation/NEON/ArithmeticSubtraction.cpp b/tests/validation/NEON/ArithmeticSubtraction.cpp
index f468f6d5c6..12fe64c396 100644
--- a/tests/validation/NEON/ArithmeticSubtraction.cpp
+++ b/tests/validation/NEON/ArithmeticSubtraction.cpp
@@ -70,6 +70,10 @@ const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset::
const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }),
framework::dataset::make("DataType", DataType::S16)),
framework::dataset::make("DataType", DataType::S16));
+
+const auto ArithmeticSubtractionS32Dataset = combine(combine(framework::dataset::make("DataType", DataType::S32),
+ framework::dataset::make("DataType", DataType::S32)),
+ framework::dataset::make("DataType", DataType::S32));
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
const auto ArithmeticSubtractionFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16),
framework::dataset::make("DataType", DataType::F16)),
@@ -120,12 +124,14 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
})),
framework::dataset::make("ConvertPolicy",{ ConvertPolicy::WRAP,
ConvertPolicy::SATURATE,
ConvertPolicy::SATURATE,
ConvertPolicy::WRAP,
ConvertPolicy::WRAP,
+ ConvertPolicy::WRAP,
})),
framework::dataset::make("Expected", { true, true, false, false, false, false})),
input1_info, input2_info, output_info, policy, expected)
@@ -270,6 +276,24 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int16_t>, framew
}
TEST_SUITE_END() // S16
+TEST_SUITE(S32)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int32_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS32Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
+ OutOfPlaceDataSet))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int32_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS32Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
+ OutOfPlaceDataSet))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // S32
+
TEST_SUITE(Float)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(F16)
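
Closing note: the extra row added to the Validate test earlier in this file pairs QASYMM8 inputs with WRAP and expects rejection, exercising the fixed quantized-WRAP check directly. The same expectation in isolation (a sketch using the function-level validate from this patch):

    TensorInfo q1(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8);
    TensorInfo q2(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8);
    TensorInfo qo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8);
    Status s = NEArithmeticSubtraction::validate(&q1, &q2, &qo, ConvertPolicy::WRAP);
    // s.error_code() != ErrorCode::OK: WRAP with quantized data types is rejected.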