aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2017-07-11 15:00:52 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:16:42 +0100
commit81f0d15d6840a0ae8ef571114555a26da74c4a43 (patch)
treea9eeda0a2b69961cd6a51d74e039bbed26a9b436 /src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
parentf70256bd46f03090281581c152bd17b4a50febcd (diff)
downloadComputeLibrary-81f0d15d6840a0ae8ef571114555a26da74c4a43.tar.gz
COMPMID-444: Add support for QS8/QS16 NEON Arithmetic Add/Sub/Mul.
Change-Id: Ia482498688ca1884272b5062e3415e736e03d36f Reviewed-on: http://mpd-gerrit.cambridge.arm.com/80448 Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp64
1 files changed, 57 insertions, 7 deletions
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index 83d6d8218e..150db39695 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -148,6 +148,46 @@ void mul_QS8_QS8_QS8_n(const void *__restrict input1_ptr, const void *__restrict
}
template <bool is_scale255, bool is_sat>
+void mul_QS16_QS16_QS16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n, int fixed_point_position)
+{
+ // n is the exponent of the scaling factor, that is scale = 1/2^n. Currently, we only support scaling factor equal to 1 => n = 0.
+ ARM_COMPUTE_ERROR_ON_MSG(n != 0, "Scaling factor different than 1 not supported for 16-bit fixed-point pixel-wise multiplication");
+ ARM_COMPUTE_UNUSED(n);
+
+ const qint16x8x2_t ta1 = vld2q_qs16(static_cast<const qint16_t *__restrict>(input1_ptr));
+ const qint16x8x2_t ta2 = vld2q_qs16(static_cast<const qint16_t *__restrict>(input2_ptr));
+
+ if(is_sat)
+ {
+ const qint16x8x2_t res =
+ {
+ {
+ // First 8 elements
+ vqmulq_qs16(ta1.val[0], ta2.val[0], fixed_point_position),
+ // Second 8 elements
+ vqmulq_qs16(ta1.val[1], ta2.val[1], fixed_point_position)
+ }
+ };
+
+ vst2q_s16(static_cast<qint16_t *__restrict>(output_ptr), res);
+ }
+ else
+ {
+ const qint16x8x2_t res =
+ {
+ {
+ // First 8 elements
+ vmulq_qs16(ta1.val[0], ta2.val[0], fixed_point_position),
+ // Second 8 elements
+ vmulq_qs16(ta1.val[1], ta2.val[1], fixed_point_position)
+ }
+ };
+
+ vst2q_s16(static_cast<qint16_t *__restrict>(output_ptr), res);
+ }
+}
+
+template <bool is_scale255, bool is_sat>
inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &input1, const int16x8_t &input2, int n)
{
int32x4_t tmp1_high = vmovl_s16(vget_high_s16(input1));
@@ -389,16 +429,15 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe
}
ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
"Output can only be U8 if both inputs are U8");
- if(input1->info()->data_type() == DataType::QS8)
+ if(is_data_type_fixed_point(input1->info()->data_type()) || is_data_type_fixed_point(input2->info()->data_type()) || is_data_type_fixed_point(output->info()->data_type()))
{
- // All data types must be QS8
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input1, input2, output);
+ // Check that all data types are the same and all fixed-point positions are the same
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
}
_input1 = input1;
@@ -513,6 +552,17 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe
_func_q_int = is_sat ? &mul_QS8_QS8_QS8_n<false, true> : &mul_QS8_QS8_QS8_n<false, false>;
}
}
+ else if(DataType::QS16 == dt_input1 && DataType::QS16 == dt_input2 && DataType::QS16 == dt_output)
+ {
+ if(is_scale_255)
+ {
+ _func_q_int = is_sat ? &mul_QS16_QS16_QS16_n<true, true> : &mul_QS16_QS16_QS16_n<true, false>;
+ }
+ else
+ {
+ _func_q_int = is_sat ? &mul_QS16_QS16_QS16_n<false, true> : &mul_QS16_QS16_QS16_n<false, false>;
+ }
+ }
else if(DataType::F16 == dt_input1 && DataType::F16 == dt_input2 && DataType::F16 == dt_output)
{
_func_float = &mul_F16_F16_F16_n<false, false>;