aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp154
1 files changed, 4 insertions, 150 deletions
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index 193ca3799c..0ec7e823a1 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -61,9 +61,9 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i
ARM_COMPUTE_UNUSED(overflow_policy);
ARM_COMPUTE_UNUSED(rounding_policy);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
"Output can only be U8 if both inputs are U8");
@@ -71,14 +71,6 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
- if(is_data_type_fixed_point(input1->data_type()) || is_data_type_fixed_point(input2->data_type()) || is_data_type_fixed_point(output->data_type()))
- {
- // Check that all data types are the same and all fixed-point positions are the same
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
- // Check if scale is representable in fixed-point with the provided settings
- ARM_COMPUTE_RETURN_ERROR_ON_VALUE_NOT_REPRESENTABLE_IN_FIXED_POINT(scale, input1);
- }
-
if(std::abs(scale - scale255_constant) < 0.00001f)
{
ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
@@ -120,11 +112,6 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
{
set_format_if_unknown(*output, Format::F16);
}
- else if(input1->data_type() == DataType::QS8 && input2->data_type() == DataType::QS8)
- {
- set_data_type_if_unknown(*output, DataType::QS8);
- set_fixed_point_position_if_zero(*output, input1->fixed_point_position());
- }
}
// Configure kernel window
@@ -220,105 +207,6 @@ void mul_U8_U8_U8_n(const void *__restrict input1_ptr, const void *__restrict in
}
template <bool is_scale255, bool is_sat>
-void mul_QS8_QS8_QS8_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n, int fixed_point_position)
-{
- const auto output = static_cast<qint8_t *__restrict>(output_ptr);
-
- const qint8x16_t ta1 = vld1q_qs8(static_cast<const qint8_t *__restrict>(input1_ptr));
- const qint8x16_t ta2 = vld1q_qs8(static_cast<const qint8_t *__restrict>(input2_ptr));
-
- if(is_scale255)
- {
- qint16x8_t tmp1_high = vmovl_s8(vget_high_s8(ta1));
- qint16x8_t tmp1_low = vmovl_s8(vget_low_s8(ta1));
- const qint16x8_t tmp2_high = vmovl_s8(vget_high_s8(ta2));
- const qint16x8_t tmp2_low = vmovl_s8(vget_low_s8(ta2));
-
- const float32x4x2_t scale255_f32 =
- {
- {
- scale255_constant_f32q,
- scale255_constant_f32q
- }
- };
- const qint16x8_t scale255 = vqcvtq_qs16_f32(scale255_f32, fixed_point_position);
-
- tmp1_high = vmulq_qs16(tmp1_high, tmp2_high, fixed_point_position);
- tmp1_low = vmulq_qs16(tmp1_low, tmp2_low, fixed_point_position);
- tmp1_high = vmulq_qs16(tmp1_high, scale255, fixed_point_position);
- tmp1_low = vmulq_qs16(tmp1_low, scale255, fixed_point_position);
-
- if(is_sat)
- {
- vst1q_qs8(output, vcombine_s8(vqmovn_s16(tmp1_low), vqmovn_s16(tmp1_high)));
- }
- else
- {
- vst1q_qs8(output, vcombine_s8(vmovn_s16(tmp1_low), vmovn_s16(tmp1_high)));
- }
- }
- else
- {
- const qint8x16_t vn = vdupq_n_s8(-n);
- qint8x16_t res = ta2;
-
- if(is_sat)
- {
- res = vqshlq_s8(vqmulq_qs8(ta1, res, fixed_point_position), vn);
- }
- else
- {
- res = vshlq_s8(vmulq_qs8(ta1, res, fixed_point_position), vn);
- }
- vst1q_qs8(output, res);
- }
-}
-
-template <bool is_scale255, bool is_sat>
-void mul_QS16_QS16_QS16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n, int fixed_point_position)
-{
- const qint16x8x2_t ta1 = vld2q_qs16(static_cast<const qint16_t *__restrict>(input1_ptr));
- qint16x8x2_t res = vld2q_qs16(static_cast<const qint16_t *__restrict>(input2_ptr));
-
- if(is_scale255)
- {
- const float32x4x2_t scale255_f32 =
- {
- {
- scale255_constant_f32q,
- scale255_constant_f32q
- }
- };
- const qint16x8_t scale255 = vqcvtq_qs16_f32(scale255_f32, fixed_point_position);
- if(is_sat)
- {
- res.val[0] = vqmulq_qs16(vqmulq_qs16(ta1.val[0], res.val[0], fixed_point_position), scale255, fixed_point_position);
- res.val[1] = vqmulq_qs16(vqmulq_qs16(ta1.val[1], res.val[1], fixed_point_position), scale255, fixed_point_position);
- }
- else
- {
- res.val[0] = vmulq_qs16(vmulq_qs16(ta1.val[0], res.val[0], fixed_point_position), scale255, fixed_point_position);
- res.val[1] = vmulq_qs16(vmulq_qs16(ta1.val[1], res.val[1], fixed_point_position), scale255, fixed_point_position);
- }
- }
- else
- {
- const qint16x8_t vn = vdupq_n_s16(-n);
- if(is_sat)
- {
- res.val[0] = vqshlq_s16(vqmulq_qs16(ta1.val[0], res.val[0], fixed_point_position), vn);
- res.val[1] = vqshlq_s16(vqmulq_qs16(ta1.val[1], res.val[1], fixed_point_position), vn);
- }
- else
- {
- res.val[0] = vshlq_s16(vmulq_qs16(ta1.val[0], res.val[0], fixed_point_position), vn);
- res.val[1] = vshlq_s16(vmulq_qs16(ta1.val[1], res.val[1], fixed_point_position), vn);
- }
- }
- vst2q_s16(static_cast<qint16_t *__restrict>(output_ptr), res);
-}
-
-template <bool is_scale255, bool is_sat>
inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &input1, const int16x8_t &input2, int n)
{
int32x4_t tmp1_high = vmovl_s16(vget_high_s16(input1));
@@ -529,7 +417,7 @@ void mul_U8_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict
} // namespace
NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
- : _func_float(nullptr), _func_int(nullptr), _func_q_int(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
+ : _func_float(nullptr), _func_int(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
{
}
@@ -550,7 +438,6 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe
_scale = scale;
_scale_exponent = 0;
_func_int = nullptr;
- _func_q_int = nullptr;
_func_float = nullptr;
bool is_scale_255 = false;
@@ -630,28 +517,6 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe
_func_int = is_sat ? &mul_U8_U8_S16_n<false, true> : &mul_U8_U8_S16_n<false, false>;
}
}
- else if(DataType::QS8 == dt_input1 && DataType::QS8 == dt_input2 && DataType::QS8 == dt_output)
- {
- if(is_scale_255)
- {
- _func_q_int = is_sat ? &mul_QS8_QS8_QS8_n<true, true> : &mul_QS8_QS8_QS8_n<true, false>;
- }
- else
- {
- _func_q_int = is_sat ? &mul_QS8_QS8_QS8_n<false, true> : &mul_QS8_QS8_QS8_n<false, false>;
- }
- }
- else if(DataType::QS16 == dt_input1 && DataType::QS16 == dt_input2 && DataType::QS16 == dt_output)
- {
- if(is_scale_255)
- {
- _func_q_int = is_sat ? &mul_QS16_QS16_QS16_n<true, true> : &mul_QS16_QS16_QS16_n<true, false>;
- }
- else
- {
- _func_q_int = is_sat ? &mul_QS16_QS16_QS16_n<false, true> : &mul_QS16_QS16_QS16_n<false, false>;
- }
- }
else if(DataType::F16 == dt_input1 && DataType::F16 == dt_input2 && DataType::F16 == dt_output)
{
_func_float = &mul_F16_F16_F16_n<false, false>;
@@ -724,17 +589,6 @@ void NEPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo
},
input1, input2, output);
}
- else if(_func_q_int != nullptr)
- {
- int fixed_point_position = _input1->info()->fixed_point_position();
- execute_window_loop(collapsed, [&](const Coordinates & id)
- {
- (*_func_q_int)(input1.ptr(), input2.ptr(), output.ptr(), _scale_exponent, fixed_point_position);
- collapsed.slide_window_slice_3D(slice_input1);
- collapsed.slide_window_slice_3D(slice_input2);
- },
- input1, input2, output);
- }
else
{
ARM_COMPUTE_ERROR_ON(_func_float == nullptr);