author     Gian Marco Iodice <gianmarco.iodice@arm.com>   2019-08-30 17:50:15 +0100
committer  Georgios Pinitas <georgios.pinitas@arm.com>    2019-08-31 15:39:27 +0100
commit     bf94516db968de2f4d839786afc84840ac495ea1 (patch)
tree       ad1b4cdebb7c3844494db9c5314f8ad8de5425b7
parent     7243fc3c07230948f83a0929712f4aff0103d19c (diff)
COMPMID-2640: Fix performance regression for Resnet101 Int8 on NEON
Change-Id: I32c8b67c5ce0918cc5603807bad80952ea2fd097
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1848
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
-rw-r--r--  arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h |   1
-rw-r--r--  src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp       | 110
2 files changed, 77 insertions(+), 34 deletions(-)
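For context on the change below: before this commit the QASYMM8 path re-read the three UniformQuantizationInfo structs on every window slice and went through the generic vdequantize/vquantize helpers, whereas the new mul_saturate_QASYMM8_QASYMM8_QASYMM8_n_opt receives the quantization parameters already broadcast into NEON registers once per run() call and fuses the requantization into a multiply-accumulate. A minimal scalar sketch of the per-element math the kernel vectorizes (the helper name and the lround-based rounding are illustrative, not library code):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar reference of the QASYMM8 x QASYMM8 -> QASYMM8 pixel-wise multiply.
// (s1, z1), (s2, z2) and (so, zo) are the scale/offset pairs of input1,
// input2 and the output; `scale` is the kernel's extra scaling factor.
inline uint8_t pixelwise_mul_qasymm8(uint8_t q1, uint8_t q2,
                                     float s1, int32_t z1,
                                     float s2, int32_t z2,
                                     float so, int32_t zo,
                                     float scale)
{
    const float in1 = (static_cast<int32_t>(q1) - z1) * s1; // dequantize input1
    const float in2 = (static_cast<int32_t>(q2) - z2) * s2; // dequantize input2
    const float out = in1 * in2;                            // multiply in float
    // Requantize: the patch folds so / scale into a single reciprocal
    // (vinvscale) and fuses offset + scaling via vmlaq_f32 before rounding.
    const float   invscale = 1.f / (so / scale);
    const int32_t q        = static_cast<int32_t>(std::lround(out * invscale + static_cast<float>(zo)));
    return static_cast<uint8_t>(std::min(255, std::max(0, q)));
}

Note that the vectorized version rounds to nearest with vcvtnq_s32_f32 on AArch64, while the armv7 fallback uses vcvtq_s32_f32, which truncates toward zero, so the two paths can differ by one LSB.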
diff --git a/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h b/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h
index e2ea90a33f..a199a1188c 100644
--- a/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h
+++ b/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h
@@ -127,6 +127,7 @@ private:
ITensor *_output;
float _scale;
int _scale_exponent;
+ bool _run_optimized_qasymm8;
};
/** Interface for the complex pixelwise multiplication kernel. */
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index 711bde3a2b..1dab5d955d 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -191,8 +191,8 @@ inline uint16x8_t scale255_U16_U16(uint16x8_t in)
return vreinterpretq_u16_s16(vcombine_s16(vmovn_s32(tmp_s2), vmovn_s32(tmp_s1)));
}
-void mul_saturate_QASYMM8_QASYMM8_QASYMM8_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale,
- const UniformQuantizationInfo &input1_qua_info, const UniformQuantizationInfo &input2_qua_info, const UniformQuantizationInfo &output_qua_info)
+inline void mul_saturate_QASYMM8_QASYMM8_QASYMM8_n_opt(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale,
+ float32x4_t input1_vscale, int32x4_t input1_voffset, float32x4_t input2_vscale, int32x4_t input2_voffset, float32x4_t output_voffset, float32x4_t vinvscale)
{
const auto input1 = static_cast<const qasymm8_t *__restrict>(input1_ptr);
const auto input2 = static_cast<const qasymm8_t *__restrict>(input2_ptr);
@@ -202,21 +202,40 @@ void mul_saturate_QASYMM8_QASYMM8_QASYMM8_n(const void *__restrict input1_ptr, c
const qasymm8x16_t input2_q = vld1q_u8(input2);
// Dequantize inputs
- const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
- const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);
-
- const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };
-
- const float32x4x4_t out_f32x4x4 =
- {
- vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
- vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
- vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
- vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3])
- };
-
- const uint8x16_t result = vquantize(out_f32x4x4, tmp_qua_info);
- vst1q_u8(output, result);
+ float32x4x4_t in1_f32x4x4;
+ float32x4x4_t in2_f32x4x4;
+ in1_f32x4x4.val[0] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(input1_q))))), input1_voffset)), input1_vscale);
+ in1_f32x4x4.val[1] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(input1_q))))), input1_voffset)), input1_vscale);
+ in1_f32x4x4.val[2] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(input1_q))))), input1_voffset)), input1_vscale);
+ in1_f32x4x4.val[3] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(input1_q))))), input1_voffset)), input1_vscale);
+
+ in2_f32x4x4.val[0] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(input2_q))))), input2_voffset)), input2_vscale);
+ in2_f32x4x4.val[1] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(input2_q))))), input2_voffset)), input2_vscale);
+ in2_f32x4x4.val[2] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(input2_q))))), input2_voffset)), input2_vscale);
+ in2_f32x4x4.val[3] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(input2_q))))), input2_voffset)), input2_vscale);
+
+ float32x4x4_t out_f32x4x4;
+ out_f32x4x4.val[0] = vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]);
+ out_f32x4x4.val[1] = vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]);
+ out_f32x4x4.val[2] = vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]);
+ out_f32x4x4.val[3] = vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]);
+
+ int32x4x4_t rf;
+#ifdef __aarch64__
+ rf.val[0] = vcvtnq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[0], vinvscale));
+ rf.val[1] = vcvtnq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[1], vinvscale));
+ rf.val[2] = vcvtnq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[2], vinvscale));
+ rf.val[3] = vcvtnq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[3], vinvscale));
+#else //__aarch64__
+ rf.val[0] = vcvtq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[0], vinvscale));
+ rf.val[1] = vcvtq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[1], vinvscale));
+ rf.val[2] = vcvtq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[2], vinvscale));
+ rf.val[3] = vcvtq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[3], vinvscale));
+#endif //__aarch64__
+ const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(rf.val[0]), vqmovn_s32(rf.val[1])));
+ const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(rf.val[2]), vqmovn_s32(rf.val[3])));
+
+ vst1q_u8(output, vcombine_u8(pa, pb));
}
void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale,
@@ -534,7 +553,7 @@ void mul_U8_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict
} // namespace
NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
- : _func_float(nullptr), _func_int(nullptr), _func_quantized(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
+ : _func_float(nullptr), _func_int(nullptr), _func_quantized(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 }, _run_optimized_qasymm8(false)
{
}
@@ -549,14 +568,15 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe
auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- _input1 = input1;
- _input2 = input2;
- _output = output;
- _scale = scale;
- _scale_exponent = 0;
- _func_quantized = nullptr;
- _func_int = nullptr;
- _func_float = nullptr;
+ _input1 = input1;
+ _input2 = input2;
+ _output = output;
+ _scale = scale;
+ _scale_exponent = 0;
+ _func_quantized = nullptr;
+ _func_int = nullptr;
+ _func_float = nullptr;
+ _run_optimized_qasymm8 = false;
bool is_scale_255 = false;
// Check and validate scaling factor
@@ -582,7 +602,7 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe
if(dt_input1 == DataType::QASYMM8 && dt_input2 == DataType::QASYMM8)
{
- _func_quantized = &mul_saturate_QASYMM8_QASYMM8_QASYMM8_n;
+ _run_optimized_qasymm8 = true;
}
else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16)
{
@@ -707,14 +727,36 @@ void NEPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo
if(is_data_type_quantized(_input1->info()->data_type()))
{
- execute_window_loop(collapsed, [&](const Coordinates &)
+ if(_run_optimized_qasymm8)
{
- (*_func_quantized)(input1.ptr(), input2.ptr(), output.ptr(), _scale,
- _input1->info()->quantization_info().uniform(), _input2->info()->quantization_info().uniform(), _output->info()->quantization_info().uniform());
- ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input1));
- ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input2));
- },
- input1, input2, output);
+ const int32x4_t input1_voffset = vdupq_n_s32(_input1->info()->quantization_info().uniform().offset);
+ const float32x4_t input1_vscale = vdupq_n_f32(_input1->info()->quantization_info().uniform().scale);
+ const int32x4_t input2_voffset = vdupq_n_s32(_input2->info()->quantization_info().uniform().offset);
+ const float32x4_t input2_vscale = vdupq_n_f32(_input2->info()->quantization_info().uniform().scale);
+ const float32x4_t output_voffset = vdupq_n_f32(static_cast<float>(_output->info()->quantization_info().uniform().offset));
+ const float output_scale = _output->info()->quantization_info().uniform().scale;
+ const float32x4_t vinvscale = vdupq_n_f32(1.f / (output_scale / _scale));
+
+ execute_window_loop(collapsed, [&](const Coordinates &)
+ {
+ mul_saturate_QASYMM8_QASYMM8_QASYMM8_n_opt(input1.ptr(), input2.ptr(), output.ptr(), _scale,
+ input1_vscale, input1_voffset, input2_vscale, input2_voffset, output_voffset, vinvscale);
+ ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input1));
+ ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input2));
+ },
+ input1, input2, output);
+ }
+ else
+ {
+ execute_window_loop(collapsed, [&](const Coordinates &)
+ {
+ (*_func_quantized)(input1.ptr(), input2.ptr(), output.ptr(), _scale,
+ _input1->info()->quantization_info().uniform(), _input2->info()->quantization_info().uniform(), _output->info()->quantization_info().uniform());
+ ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input1));
+ ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input2));
+ },
+ input1, input2, output);
+ }
}
else if(_func_int != nullptr)
{