aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h24
-rw-r--r--src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp52
-rw-r--r--tests/validation/NEON/ArithmeticSubtraction.cpp18
3 files changed, 85 insertions, 9 deletions
diff --git a/arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h b/arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h
index 6f88d2757a..a11bf44458 100644
--- a/arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h
+++ b/arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h
@@ -50,9 +50,21 @@ public:
/** Initialise the kernel's input, output and border mode.
*
- * @param[in] input1 An input tensor. Data types supported: U8/QS8/QS16/S16/F32
- * @param[in] input2 An input tensor. Data types supported: U8, QS8 (only if @p input1 is QS8),QS16 (only if @p input1 is QS16), S16/F32 (only if @p input1 is F32).
- * @param[out] output The output tensor. Data types supported: U8 (Only if both inputs are U8), QS8 (only if both inputs are QS8), QS16 (only if both inputs are QS16), S16/F32 (only if both inputs are F32).
+ * Valid configurations (Input1,Input2) -> Output :
+ *
+ * - (U8,U8) -> U8
+ * - (QS8,QS8) -> QS8
+ * - (U8,U8) -> S16
+ * - (S16,U8) -> S16
+ * - (U8,S16) -> S16
+ * - (S16,S16) -> S16
+ * - (QS16,QS16) -> QS16
+ * - (F16,F16) -> F16
+ * - (F32,F32) -> F32
+ *
+ * @param[in] input1 An input tensor. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[in] input2 An input tensor. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[out] output The output tensor. Data types supported: U8/QS8/QS16/S16/F16/F32.
* @param[in] policy Overflow policy.
*/
void configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy);
@@ -63,9 +75,9 @@ public:
private:
/** Common signature for all the specialised sub functions
*
- * @param[in] input1 An input tensor. Data types supported: U8/S16/F32
- * @param[in] input2 An input tensor. Data types supported: U8/S16/F32 (only if @p input1 is F32).
- * @param[out] output The output tensor. Data types supported: U8 (Only if both inputs are U8), S16/F32 (only if both inputs are F32).
+ * @param[in] input1 An input tensor. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[in] input2 An input tensor. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[out] output The output tensor. Data types supported: U8/QS8/QS16/S16/F16/F32.
* @param[in] window Region on which to execute the kernel.
*/
using SubFunction = void(const ITensor *input1, const ITensor *input2, ITensor *output, const Window &window);
diff --git a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
index cac2a6bd05..be8574317b 100644
--- a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
@@ -157,6 +157,45 @@ void sub_saturate_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *o
input1, input2, output);
}
+#ifdef ARM_COMPUTE_ENABLE_FP16
+inline float16x8x2_t vsub2q_f16(const float16x8x2_t &a, const float16x8x2_t &b)
+{
+ const float16x8x2_t res =
+ {
+ {
+ vsubq_f16(a.val[0], b.val[0]),
+ vsubq_f16(a.val[1], b.val[1])
+ }
+ };
+
+ return res;
+}
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
+void sub_F16_F16_F16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
+{
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ Iterator input1(in1, window);
+ Iterator input2(in2, window);
+ Iterator output(out, window);
+
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ const float16x8x2_t a = vld2q_f16(reinterpret_cast<const float16_t *>(input1.ptr()));
+ const float16x8x2_t b = vld2q_f16(reinterpret_cast<const float16_t *>(input2.ptr()));
+
+ vst2q_f16(reinterpret_cast<float16_t *>(output.ptr()), vsub2q_f16(a, b));
+ },
+ input1, input2, output);
+#else /* ARM_COMPUTE_ENABLE_FP16 */
+ ARM_COMPUTE_UNUSED(in1);
+ ARM_COMPUTE_UNUSED(in2);
+ ARM_COMPUTE_UNUSED(out);
+ ARM_COMPUTE_UNUSED(window);
+ ARM_COMPUTE_ERROR("Not supported, recompile the library with arch=arm64-v8.2-a");
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+}
+
void sub_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
Iterator input1(in1, window);
@@ -328,6 +367,10 @@ void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITens
{
set_format_if_unknown(*output->info(), Format::S16);
}
+ else if(input1->info()->data_type() == DataType::F16 || input2->info()->data_type() == DataType::F16)
+ {
+ set_format_if_unknown(*output->info(), Format::F16);
+ }
else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
{
set_format_if_unknown(*output->info(), Format::F32);
@@ -335,9 +378,9 @@ void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITens
}
ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
"Output can only be U8 if both inputs are U8");
if(is_data_type_fixed_point(input1->info()->data_type()) || is_data_type_fixed_point(input2->info()->data_type()) || is_data_type_fixed_point(output->info()->data_type()))
@@ -364,6 +407,9 @@ void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITens
{ "sub_saturate_S16_S16_S16", &sub_saturate_S16_S16_S16 },
{ "sub_wrap_F32_F32_F32", &sub_F32_F32_F32 },
{ "sub_saturate_F32_F32_F32", &sub_F32_F32_F32 },
+ { "sub_wrap_F16_F16_F16", &sub_F16_F16_F16 },
+ { "sub_saturate_F16_F16_F16", &sub_F16_F16_F16 },
+
};
_input1 = input1;
diff --git a/tests/validation/NEON/ArithmeticSubtraction.cpp b/tests/validation/NEON/ArithmeticSubtraction.cpp
index 54cd9f04ba..23a320a84d 100644
--- a/tests/validation/NEON/ArithmeticSubtraction.cpp
+++ b/tests/validation/NEON/ArithmeticSubtraction.cpp
@@ -200,6 +200,7 @@ BOOST_DATA_TEST_CASE(RunSmall, SmallShapes() * ConvertPolicies() * boost::unit_t
// Validate output
validate(NEAccessor(dst), ref_dst);
}
+
BOOST_TEST_DECORATOR(*boost::unit_test::label("nightly"))
BOOST_DATA_TEST_CASE(RunLarge, LargeShapes() * ConvertPolicies() * boost::unit_test::data::xrange(1, 7),
shape, policy, fixed_point_position)
@@ -245,6 +246,23 @@ BOOST_DATA_TEST_CASE(RunLarge, LargeShapes() * ConvertPolicies() * boost::unit_t
BOOST_AUTO_TEST_SUITE_END()
BOOST_AUTO_TEST_SUITE_END()
+#ifdef ARM_COMPUTE_ENABLE_FP16
+BOOST_AUTO_TEST_SUITE(Float16)
+BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
+BOOST_DATA_TEST_CASE(RunSmall, SmallShapes(), shape)
+{
+ // Compute function
+ Tensor dst = compute_arithmetic_subtraction(shape, DataType::F16, DataType::F16, DataType::F16, ConvertPolicy::WRAP);
+
+ // Compute reference
+ RawTensor ref_dst = Reference::compute_reference_arithmetic_subtraction(shape, DataType::F16, DataType::F16, DataType::F16, ConvertPolicy::WRAP);
+
+ // Validate output
+ validate(NEAccessor(dst), ref_dst);
+}
+BOOST_AUTO_TEST_SUITE_END()
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
BOOST_AUTO_TEST_SUITE(Float)
BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit") * boost::unit_test::label("nightly"))
BOOST_DATA_TEST_CASE(Configuration, (SmallShapes() + LargeShapes()) * boost::unit_test::data::make({ ConvertPolicy::SATURATE, ConvertPolicy::WRAP }),