aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
diff options
context:
space:
mode:
authorManuel Bottini <manuel.bottini@arm.com>2019-06-26 15:17:09 +0100
committerManuel Bottini <manuel.bottini@arm.com>2019-07-03 12:46:08 +0000
commit7bb56c6337997281df10fa28ad7924c921b920eb (patch)
treeaf1ee9244c7c0f9265bb6d075816b18fac2f66df /src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
parent6b9f388f719dc9ff1181c9a43a41140f19e15ec8 (diff)
downloadComputeLibrary-7bb56c6337997281df10fa28ad7924c921b920eb.tar.gz
COMPMID-2409: Add QSYMM16 support for PixelWiseMultiplication for NEON
Change-Id: Idfd3b45857201d5143242f9517d3353150b2c923 Signed-off-by: Manuel Bottini <manuel.bottini@arm.com> Reviewed-on: https://review.mlplatform.org/c/1422 Reviewed-by: Pablo Marquez <pablo.tello@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp76
1 files changed, 63 insertions, 13 deletions
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index c313b23ad3..6aaac818e9 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -30,6 +30,7 @@
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEAsymm.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
+#include "arm_compute/core/NEON/NESymm.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
@@ -63,21 +64,30 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i
ARM_COMPUTE_UNUSED(rounding_policy);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
"Output can only be U8 if both inputs are U8");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::QASYMM8 && input2->data_type() != DataType::QASYMM8,
- "Input2 must be QASYMM8 if both input1 is QASYMM8");
+ "Input2 must be QASYMM8 if input1 is QASYMM8");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::QASYMM8 && input2->data_type() == DataType::QASYMM8 && overflow_policy == ConvertPolicy::WRAP,
- "ConvertPolicy cannot be WRAP if datatype is QASYMM8");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() != DataType::QASYMM8 && input2->data_type() == DataType::QASYMM8,
+ "Input1 must be QASYMM8 if input2 is QASYMM8");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::QSYMM16 && input2->data_type() != DataType::QSYMM16,
+ "Input2 must be QSYMM16 if input1 is QSYMM16");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() != DataType::QSYMM16 && input2->data_type() == DataType::QSYMM16,
+ "Input1 must be QSYMM16 if input2 is QSYMM16");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_quantized(input1->data_type()) && overflow_policy == ConvertPolicy::WRAP,
+ "ConvertPolicy cannot be WRAP if datatype is quantized");
if(output->total_size() > 0)
{
- if(output->data_type() == DataType::QASYMM8)
+ if(is_data_type_quantized(output->data_type()))
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2, output);
}
@@ -128,6 +138,14 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
{
set_format_if_unknown(*output, Format::F16);
}
+ else if(input1->data_type() == DataType::QASYMM8)
+ {
+ set_data_type_if_unknown(*output, DataType::QASYMM8);
+ }
+ else if(input1->data_type() == DataType::QSYMM16)
+ {
+ set_data_type_if_unknown(*output, DataType::QSYMM16);
+ }
}
// Configure kernel window
@@ -201,6 +219,34 @@ void mul_saturate_QASYMM8_QASYMM8_QASYMM8_n(const void *__restrict input1_ptr, c
vst1q_u8(output, result);
}
+void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale,
+ const UniformQuantizationInfo &input1_qua_info, const UniformQuantizationInfo &input2_qua_info, const UniformQuantizationInfo &output_qua_info)
+{ // Multiplies 16 QSYMM16 values from each input in float and stores 16 requantized QSYMM16 results.
+ const auto input1 = static_cast<const qsymm16_t *__restrict>(input1_ptr);
+ const auto input2 = static_cast<const qsymm16_t *__restrict>(input2_ptr);
+ const auto output = static_cast<qsymm16_t *__restrict>(output_ptr);
+ // vld2q_s16 reads 16 int16 values and de-interleaves them; both inputs get the same lane order and vst2q_s16 below restores it.
+ const qsymm16x8x2_t input1_q = vld2q_s16(input1);
+ const qsymm16x8x2_t input2_q = vld2q_s16(input2);
+
+ // Dequantize each input to float using its own quantization info
+ const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
+ const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);
+ // Output scale divided by `scale`: presumably folds the caller's multiplier into requantization (assumes vquantize_qsymm16 divides by the scale - TODO confirm).
+ const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };
+ // Element-wise float product of the four float32x4 groups.
+ const float32x4x4_t out_f32x4x4 =
+ {
+ vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
+ vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
+ vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
+ vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
+ };
+ // NOTE(review): the "saturate" in the name presumably relies on vquantize_qsymm16 clamping to the int16 range - confirm.
+ const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info);
+ vst2q_s16(output, result);
+}
+
template <bool is_scale255, bool is_sat>
void mul_U8_U8_U8_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
{
@@ -488,7 +534,7 @@ void mul_U8_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict
} // namespace
NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
- : _func_float(nullptr), _func_int(nullptr), _func_qasymm8(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
+ : _func_float(nullptr), _func_int(nullptr), _func_quantized(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 } // all function and tensor pointers start null, scale values start at zero
{
}
@@ -508,7 +554,7 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe
_output = output;
_scale = scale;
_scale_exponent = 0;
- _func_qasymm8 = nullptr;
+ _func_quantized = nullptr;
_func_int = nullptr;
_func_float = nullptr;
@@ -536,7 +582,11 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe
if(dt_input1 == DataType::QASYMM8 && dt_input2 == DataType::QASYMM8)
{
- _func_qasymm8 = &mul_saturate_QASYMM8_QASYMM8_QASYMM8_n;
+ _func_quantized = &mul_saturate_QASYMM8_QASYMM8_QASYMM8_n;
+ }
+ else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16)
+ {
+ _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16_n;
}
else if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::U8 == dt_output)
{
@@ -655,12 +705,12 @@ void NEPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo
Iterator input2(_input2, slice_input2);
Iterator output(_output, slice);
- if(_func_qasymm8 != nullptr)
+ if(is_data_type_quantized(_input1->info()->data_type()))
{
execute_window_loop(collapsed, [&](const Coordinates &)
{
- (*_func_qasymm8)(input1.ptr(), input2.ptr(), output.ptr(), _scale,
- _input1->info()->quantization_info().uniform(), _input2->info()->quantization_info().uniform(), _output->info()->quantization_info().uniform());
+ (*_func_quantized)(input1.ptr(), input2.ptr(), output.ptr(), _scale,
+ _input1->info()->quantization_info().uniform(), _input2->info()->quantization_info().uniform(), _output->info()->quantization_info().uniform());
collapsed.slide_window_slice_3D(slice_input1);
collapsed.slide_window_slice_3D(slice_input2);
},