From 79fa9a22022824735986f74557bf38095eb2284d Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Fri, 22 Feb 2019 17:54:22 +0000 Subject: COMPMID-2009: Add support for QASYMM8 in NEPixelWiseMultiplicationKernel Change-Id: I58536e945d069c96a065b82cc14960f54afc6e1a Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/781 Tested-by: Arm Jenkins Reviewed-by: Pablo Marquez --- .../kernels/NEPixelWiseMultiplicationKernel.cpp | 90 ++++++++++++++++++---- 1 file changed, 73 insertions(+), 17 deletions(-) (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp') diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index a4f51436b4..e3166e02b6 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2018 ARM Limited. + * Copyright (c) 2016-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -28,8 +28,10 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/IAccessWindow.h" #include "arm_compute/core/ITensor.h" +#include "arm_compute/core/NEON/NEAsymm.h" #include "arm_compute/core/NEON/NEFixedPoint.h" #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" #include @@ -42,12 +44,9 @@ #include // needed for float16_t #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ -using namespace arm_compute; - namespace arm_compute { class Coordinates; -} // namespace arm_compute namespace { @@ -63,15 +62,29 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i ARM_COMPUTE_UNUSED(rounding_policy); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), "Output can only be U8 if both inputs are U8"); - const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape()); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::QASYMM8 && input2->data_type() != DataType::QASYMM8, + "Input2 must be QASYMM8 if both input1 is QASYMM8"); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::QASYMM8 && input2->data_type() == DataType::QASYMM8 && overflow_policy == ConvertPolicy::WRAP, + "ConvertPolicy cannot be WRAP if datatype is QASYMM8"); + + if(output->total_size() > 0) + { + if(output->data_type() == DataType::QASYMM8) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2, output); + } + + const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape()); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); + } if(std::abs(scale - scale255_constant) < 0.00001f) { @@ -159,6 +172,34 @@ inline uint16x8_t scale255_U16_U16(uint16x8_t in) return vreinterpretq_u16_s16(vcombine_s16(vmovn_s32(tmp_s2), vmovn_s32(tmp_s1))); } +void mul_saturate_QASYMM8_QASYMM8_QASYMM8_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale, + const QuantizationInfo &input1_qua_info, const QuantizationInfo &input2_qua_info, const QuantizationInfo &output_qua_info) +{ + const auto input1 = static_cast(input1_ptr); + const auto input2 = static_cast(input2_ptr); + const auto output = static_cast(output_ptr); + + const qasymm8x16_t input1_q = vld1q_u8(input1); + const qasymm8x16_t input2_q = vld1q_u8(input2); + + // Dequantitize inputs + const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info); + const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info); + + const QuantizationInfo tmp_qua_info = QuantizationInfo(output_qua_info.scale / scale, output_qua_info.offset); + + const float32x4x4_t out_f32x4x4 = + { + vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]), + vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]), + vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]), + vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]) + }; + + const uint8x16_t result = vquantize(out_f32x4x4, tmp_qua_info); + vst1q_u8(output, result); +} + template void mul_U8_U8_U8_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n) { @@ -291,7 +332,6 @@ void mul_S16_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict vst2q_s16(output, result); } -template void mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale) { const auto input1 = static_cast(input1_ptr); @@ -313,7 +353,6 @@ void mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict vst4q_f32(output, result); } -template void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale) { #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -419,7 +458,7 @@ void mul_U8_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict } // namespace NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel() - : _func_float(nullptr), _func_int(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 } + : _func_float(nullptr), _func_int(nullptr), _func_qasymm8(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 } { } @@ -439,6 +478,7 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe _output = output; _scale = scale; _scale_exponent = 0; + _func_qasymm8 = nullptr; _func_int = nullptr; _func_float = nullptr; @@ -464,7 +504,11 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe const DataType dt_output = output->info()->data_type(); const bool is_sat = (overflow_policy == ConvertPolicy::SATURATE); - if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::U8 == dt_output) + if(dt_input1 == DataType::QASYMM8 && dt_input2 == DataType::QASYMM8) + { + _func_qasymm8 = &mul_saturate_QASYMM8_QASYMM8_QASYMM8_n; + } + else if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::U8 == dt_output) { if(is_scale_255) { @@ -521,12 +565,12 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe } else if(DataType::F16 == dt_input1 && DataType::F16 == dt_input2 && DataType::F16 == dt_output) { - _func_float = &mul_F16_F16_F16_n; + _func_float = &mul_F16_F16_F16_n; _func_int = nullptr; } else if(DataType::F32 == dt_input1 && DataType::F32 == dt_input2 && DataType::F32 == dt_output) { - _func_float = &mul_F32_F32_F32_n; + _func_float = &mul_F32_F32_F32_n; _func_int = nullptr; } else @@ -581,7 +625,18 @@ void NEPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo Iterator input2(_input2, slice_input2); Iterator output(_output, slice); - if(_func_int != nullptr) + if(_func_qasymm8 != nullptr) + { + execute_window_loop(collapsed, [&](const Coordinates & id) + { + (*_func_qasymm8)(input1.ptr(), input2.ptr(), output.ptr(), _scale, + _input1->info()->quantization_info(), _input2->info()->quantization_info(), _output->info()->quantization_info()); + collapsed.slide_window_slice_3D(slice_input1); + collapsed.slide_window_slice_3D(slice_input2); + }, + input1, input2, output); + } + else if(_func_int != nullptr) { execute_window_loop(collapsed, [&](const Coordinates & id) { @@ -610,3 +665,4 @@ BorderSize NEPixelWiseMultiplicationKernel::border_size() const const unsigned int border = std::min(num_elems_processed_per_iteration - 1U, replicateSize); return BorderSize(0, border, 0, 0); } +} // namespace arm_compute -- cgit v1.2.1