From 9428a182911802cf6e6df6eb751a7c7eb43602f9 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Mon, 30 Mar 2020 14:10:20 +0100 Subject: COMPMID-3237: Add support for QSYMM16 into S32 NEPixelwiseMultiplicationKernel Change-Id: I8dc3348db37b041f442639ac0d072740ca639878 Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2960 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Reviewed-by: Sang-Hoon Park Comments-Addressed: Arm Jenkins --- .../kernels/NEPixelWiseMultiplicationKernel.cpp | 175 ++++++++++++++++----- 1 file changed, 140 insertions(+), 35 deletions(-) (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp') diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index a87588dbb3..ca59e66293 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2019 ARM Limited. + * Copyright (c) 2016-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,23 +24,12 @@ #include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h" #include "arm_compute/core/CPP/Validate.h" -#include "arm_compute/core/Error.h" -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/IAccessWindow.h" -#include "arm_compute/core/ITensor.h" #include "arm_compute/core/NEON/NEAsymm.h" -#include "arm_compute/core/NEON/NEFixedPoint.h" #include "arm_compute/core/NEON/NESymm.h" #include "arm_compute/core/NEON/wrapper/wrapper.h" #include "arm_compute/core/TensorInfo.h" -#include "arm_compute/core/Types.h" -#include "arm_compute/core/Validate.h" #include -#include -#include -#include -#include #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #include // needed for float16_t @@ -48,8 +37,6 @@ namespace arm_compute { -class Coordinates; - namespace { const float scale255_constant = 1.f / 255.f; @@ -66,10 +53,9 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), - "Output can only be U8 if both inputs are U8"); - + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, + DataType::S16, DataType::QSYMM16, + DataType::S32, DataType::F16, DataType::F32); if(is_data_type_quantized(input1->data_type()) || is_data_type_quantized(input2->data_type())) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2); @@ -86,6 +72,12 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape()); ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), + "Output can only be U8 if both inputs are U8"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16), + "Output can only be S32 if both inputs are QSYMM16"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 output"); } if(std::abs(scale - scale255_constant) < 0.00001f) @@ -266,8 +258,20 @@ void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, c const auto input2 = static_cast(input2_ptr); const auto output = static_cast(output_ptr); - const qsymm16x8x2_t input1_q = vld2q_s16(input1); - const qsymm16x8x2_t input2_q = vld2q_s16(input2); + const qsymm16x8x2_t input1_q = + { + { + vld1q_s16(input1), + vld1q_s16(input1 + 8), + } + }; + const qsymm16x8x2_t input2_q = + { + { + vld1q_s16(input2), + vld1q_s16(input2 + 8), + } + }; // Dequantitize inputs const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info); @@ -284,7 +288,65 @@ void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, c }; const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info); - vst2q_s16(output, result); + vst1q_s16(output, result.val[0]); + vst1q_s16(output + 8, result.val[1]); +} + +void mul_QSYMM16_QSYMM16_S32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int scale) +{ + ARM_COMPUTE_UNUSED(scale); + const auto input1 = static_cast(input1_ptr); + const auto input2 = static_cast(input2_ptr); + const auto output = static_cast(output_ptr); + + const qsymm16x8x2_t input1_q = + { + { + vld1q_s16(input1), + vld1q_s16(input1 + 8), + } + }; + const qsymm16x8x2_t input2_q = + { + { + vld1q_s16(input2), + vld1q_s16(input2 + 8), + } + }; + + const int32x4x4_t in1_s32 = + { + { + vmovl_s16(vget_low_s16(input1_q.val[0])), + vmovl_s16(vget_high_s16(input1_q.val[0])), + vmovl_s16(vget_low_s16(input1_q.val[1])), + vmovl_s16(vget_high_s16(input1_q.val[1])), + } + }; + const int32x4x4_t in2_s32 = + { + { + vmovl_s16(vget_low_s16(input2_q.val[0])), + vmovl_s16(vget_high_s16(input2_q.val[0])), + vmovl_s16(vget_low_s16(input2_q.val[1])), + vmovl_s16(vget_high_s16(input2_q.val[1])), + } + }; + + const int32x4x4_t result = + { + { + vmulq_s32(in1_s32.val[0], in2_s32.val[0]), + vmulq_s32(in1_s32.val[1], in2_s32.val[1]), + vmulq_s32(in1_s32.val[2], in2_s32.val[2]), + vmulq_s32(in1_s32.val[3], in2_s32.val[3]), + } + }; + + vst1q_s32(output, result.val[0]); + vst1q_s32(output + 4, result.val[1]); + vst1q_s32(output + 8, result.val[2]); + vst1q_s32(output + 12, result.val[3]); } template @@ -412,11 +474,24 @@ void mul_S16_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict const auto input2 = static_cast(input2_ptr); const auto output = static_cast(output_ptr); - const int16x8x2_t ta1 = vld2q_s16(input1); - const int16x8x2_t ta2 = vld2q_s16(input2); + const int16x8x2_t ta1 = + { + { + vld1q_s16(input1), + vld1q_s16(input1 + 8), + } + }; + const int16x8x2_t ta2 = + { + { + vld1q_s16(input2), + vld1q_s16(input2 + 8), + } + }; const int16x8x2_t result = mul_S16_S16_S16_n_k(ta1, ta2, n); - vst2q_s16(output, result); + vst1q_s16(output, result.val[0]); + vst1q_s16(output + 8, result.val[1]); } void mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale) @@ -472,11 +547,23 @@ void c_mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restri void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale) { #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - const auto input1 = static_cast(input1_ptr); - const auto input2 = static_cast(input2_ptr); - const auto output = static_cast(output_ptr); - const float16x8x2_t ta1 = vld2q_f16(input1); - const float16x8x2_t ta2 = vld2q_f16(input2); + const auto input1 = static_cast(input1_ptr); + const auto input2 = static_cast(input2_ptr); + const auto output = static_cast(output_ptr); + const float16x8x2_t ta1 = + { + { + vld1q_f16(input1), + vld1q_f16(input1 + 8), + } + }; + const float16x8x2_t ta2 = + { + { + vld1q_f16(input2), + vld1q_f16(input2 + 8), + } + }; const float16x8_t scale_vec = vdupq_n_f16(scale); const float16x8x2_t result = { @@ -485,7 +572,8 @@ void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec), } }; - vst2q_f16(output, result); + vst1q_f16(output, result.val[0]); + vst1q_f16(output + 8, result.val[1]); #else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ ARM_COMPUTE_UNUSED(input1_ptr); ARM_COMPUTE_UNUSED(input2_ptr); @@ -550,8 +638,20 @@ void mul_S16_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict const auto input2 = static_cast(input2_ptr); const auto output = static_cast(output_ptr); - const int16x8x2_t ta1 = vld2q_s16(input1); - const uint8x8x2_t ta2u = vld2_u8(input2); + const int16x8x2_t ta1 = + { + { + vld1q_s16(input1), + vld1q_s16(input1 + 8), + } + }; + const uint8x8x2_t ta2u = + { + { + vld1_u8(input2), + vld1_u8(input2 + 8), + } + }; const int16x8x2_t ta2 = { { @@ -562,7 +662,8 @@ void mul_S16_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict const int16x8x2_t result = mul_S16_S16_S16_n_k(ta1, ta2, n); - vst2q_s16(output, result); + vst1q_s16(output, result.val[0]); + vst1q_s16(output + 8, result.val[1]); } template @@ -629,10 +730,14 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe { _func_quantized = &mul_saturate_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED_n; } - else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16) + else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16 && dt_output == DataType::QSYMM16) { _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16_n; } + else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16 && dt_output == DataType::S32) + { + _func_int = &mul_QSYMM16_QSYMM16_S32_n; + } else if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::U8 == dt_output) { if(is_scale_255) @@ -750,7 +855,7 @@ void NEPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo Iterator input2(_input2, slice_input2); Iterator output(_output, slice); - if(is_data_type_quantized(_input1->info()->data_type())) + if((_run_optimized_qasymm8) || (_func_quantized != nullptr)) { if(_run_optimized_qasymm8) { -- cgit v1.2.1