From ab18212dd287cc0ec9b7c1a2c72455fe75ebd13d Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Mon, 9 Oct 2017 15:05:40 +0100 Subject: COMPMID-616 - Optimizing GEMMLowp on NEON intrinsics Change-Id: Ibbeff5d37249b6e8fc34ad496035a1511c9da5a3 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/94072 Tested-by: Kaizen Reviewed-by: Pablo Tello --- .../kernels/NEGEMMLowpMatrixMultiplyKernel.cpp | 395 ++++++--------------- 1 file changed, 106 insertions(+), 289 deletions(-) (limited to 'src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp') diff --git a/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp index cbba4461a2..3e614a8bfc 100644 --- a/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp +++ b/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2017 ARM Limited. + * Copyright (c) 2017 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -23,6 +23,7 @@ */ #include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h" +#include "arm_compute/core/AccessWindowStatic.h" #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensor.h" @@ -45,35 +46,43 @@ class Coordinates; } // namespace arm_compute NEGEMMLowpMatrixMultiplyKernel::NEGEMMLowpMatrixMultiplyKernel() - : _input0(nullptr), _input1(nullptr), _output(nullptr), _a_offset(0), _b_offset(0), _output_offset(0), _output_mult_int(0), _shift(0) + : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true) { } -void NEGEMMLowpMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, - int32_t a_offset, int32_t b_offset, int32_t output_offset, int32_t output_mult_int, int32_t shift) +void NEGEMMLowpMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::U8); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output); - - _input0 = input0; - _input1 = input1; - _output = output; - _a_offset = a_offset; - _b_offset = b_offset; - _output_offset = output_offset; - _output_mult_int = output_mult_int; - _shift = shift; + ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32); + + // Check if matrix B should be slidden or not + // Don't slide matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 + // This scenario can happen when the the matrix multiplication is used to perform a convolution operation + TensorShape in0_shape = input0->info()->tensor_shape(); + TensorShape in1_shape = input1->info()->tensor_shape(); + TensorShape out_shape = output->info()->tensor_shape(); + + in0_shape.collapse(2); + in1_shape.collapse(2); + out_shape.collapse(2); + + ARM_COMPUTE_ERROR_ON_MSG(in0_shape[2] != out_shape[2], "Output tensor must have the same number of batches of input0 tensor"); + ARM_COMPUTE_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2], "Input1 tensor must have the same number of batches of input0 or the number of batches must be set to 1"); + + _input0 = input0; + _input1 = input1; + _output = output; + _slide_matrix_b = in1_shape[2] != 1; constexpr unsigned int num_elems_processed_per_iteration_x = 16; constexpr unsigned int num_elems_processed_per_iteration_y = 4; Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); - AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y); - AccessWindowHorizontal in0_access(input0->info(), 0, num_elems_processed_per_iteration_x); + AccessWindowStatic in0_access(input0->info(), 0, 0, ceil_to_multiple(input0->info()->dimension(0), 8), input0->info()->dimension(1)); AccessWindowHorizontal in1_access(input1->info(), 0, num_elems_processed_per_iteration_x); + AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y); update_window_and_padding(win, in0_access, in1_access, output_access); @@ -88,337 +97,145 @@ void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); const size_t in_b_stride = _input1->info()->strides_in_bytes()[1]; - const size_t out_stride = _output->info()->strides_in_bytes()[1]; + const size_t out_stride = _output->info()->strides_in_bytes()[1] / _output->info()->element_size(); - /* Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix */ + // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix Window win_a(window); win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); - win_a.set(Window::DimY, Window::Dimension(window.y().start() >> 2, window.y().end() >> 2, 1)); + win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, window.y().end() / 4, 1)); - /* Set step_x and step_y for matrix B. Scale by a factor of 16 the X range as the input transposed matrix A has 16 times less the cols of the output matrix */ - Window win_b(window); - win_b.set(Window::DimX, Window::Dimension(window.x().start() >> 4, window.x().end() >> 4, in_b_stride)); + // Set step_x and step_y for matrix B. Scale by a factor of 16 the X range as the input transposed matrix A has 16 times less the columns of the output matrix + Window win_b; + // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 + // This scenario can happen when the the matrix multiplication is used to perform a convolution operation + if(_slide_matrix_b) + { + win_b = window; + } + win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, in_b_stride)); win_b.set(Window::DimY, Window::Dimension(0, 0, 0)); - /* The step x and step y for the output matrix has been already set using in configure() */ + // The step x and step y for the output matrix has been already set using in configure() Iterator ina(_input0, win_a); Iterator inb(_input1, win_b); Iterator out(_output, window); - const int32x4_t voffset_a = vdupq_n_s32(_a_offset); - const int32x4_t voffset_b = vdupq_n_s32(_b_offset); - const int32x4_t vshiftr = vdupq_n_s32(-_shift); - const int width_b = _input1->info()->dimension(0); // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration // All the values needed for computing a single 4x4 block will be read from consecutive memory positions - execute_window_loop(window, [&](const Coordinates &) + execute_window_loop(window, [&](const Coordinates & id) { const uint8_t *mtx_a0 = ina.ptr(); const uint8_t *mtx_b0 = inb.ptr(); + // Note: Since the input are all positives, we can use uint32_t // Accumulators for the block 0 - int32x4x4_t c0 = + uint32x4x4_t c0 = { { - vdupq_n_s32(_output_offset), - vdupq_n_s32(_output_offset), - vdupq_n_s32(_output_offset), - vdupq_n_s32(_output_offset) + vdupq_n_u32(0), + vdupq_n_u32(0), + vdupq_n_u32(0), + vdupq_n_u32(0) } }; // Accumulators for the block 1 - int32x4x4_t c1 = + uint32x4x4_t c1 = { { - vdupq_n_s32(_output_offset), - vdupq_n_s32(_output_offset), - vdupq_n_s32(_output_offset), - vdupq_n_s32(_output_offset) + vdupq_n_u32(0), + vdupq_n_u32(0), + vdupq_n_u32(0), + vdupq_n_u32(0) } }; // Accumulators for the block 2 - int32x4x4_t c2 = + uint32x4x4_t c2 = { { - vdupq_n_s32(_output_offset), - vdupq_n_s32(_output_offset), - vdupq_n_s32(_output_offset), - vdupq_n_s32(_output_offset) + vdupq_n_u32(0), + vdupq_n_u32(0), + vdupq_n_u32(0), + vdupq_n_u32(0) } }; // Accumulators for the block 3 - int32x4x4_t c3 = + uint32x4x4_t c3 = { { - vdupq_n_s32(_output_offset), - vdupq_n_s32(_output_offset), - vdupq_n_s32(_output_offset), - vdupq_n_s32(_output_offset) + vdupq_n_u32(0), + vdupq_n_u32(0), + vdupq_n_u32(0), + vdupq_n_u32(0) } }; - int k = 0; - // This for loop performs 4 accumulations per iteration - for(; k <= (width_b - 64); k += 64, mtx_a0 += 16, mtx_b0 += 64) + for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16) { - const uint8x8_t p00 = vld1_u8(mtx_a0 + 0); - const uint8x8_t p01 = vld1_u8(mtx_a0 + 8); - const uint8x8_t q00l = vld1_u8(mtx_b0 + 0); - const uint8x8_t q00h = vld1_u8(mtx_b0 + 8); - const uint8x8_t q01l = vld1_u8(mtx_b0 + 16); - const uint8x8_t q01h = vld1_u8(mtx_b0 + 24); - const uint8x8_t q02l = vld1_u8(mtx_b0 + 32); - const uint8x8_t q02h = vld1_u8(mtx_b0 + 40); - const uint8x8_t q03l = vld1_u8(mtx_b0 + 48); - const uint8x8_t q03h = vld1_u8(mtx_b0 + 56); - - const int32x4_t ia0l = vaddw_s16(voffset_a, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(p00)))); - const int32x4_t ia0h = vaddw_s16(voffset_a, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(p00)))); - const int32x4_t ia1l = vaddw_s16(voffset_a, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(p01)))); - const int32x4_t ia1h = vaddw_s16(voffset_a, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(p01)))); - - const int32x2x4_t ia0 = - { - { - vget_low_s32(ia0l), - vget_high_s32(ia0l), - vget_low_s32(ia0h), - vget_high_s32(ia0h) - } - }; - - const int32x2x4_t ia1 = - { - { - vget_low_s32(ia1l), - vget_high_s32(ia1l), - vget_low_s32(ia1h), - vget_high_s32(ia1h) - } - }; - - const int32x4x4_t ib0 = - { - { - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q00l)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q00l)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q00h)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q00h)))) - } - }; + const uint8x8_t a00_u8 = vld1_u8(mtx_a0); + const uint8x16_t b00_u8 = vld1q_u8(mtx_b0); - const int32x4x4_t ib1 = - { - { - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q01l)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q01l)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q01h)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q01h)))) - } - }; + // Convert a00_u8 to uint16_t and get the lower part + const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8)); - const int32x4x4_t ib2 = + // Convert b00_u8 to int16_t + const uint16x4x4_t b00_u16 = { { - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q02l)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q02l)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q02h)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q02h)))) - } - }; - - const int32x4x4_t ib3 = - { - { - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q03l)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q03l)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q03h)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q03h)))) - } - }; - - // 4x4 block 0 - Accumulation 0 - c0.val[0] = vmlaq_lane_s32(c0.val[0], ib0.val[0], ia0.val[0], 0); - c0.val[1] = vmlaq_lane_s32(c0.val[1], ib0.val[0], ia0.val[0], 1); - c0.val[2] = vmlaq_lane_s32(c0.val[2], ib0.val[0], ia0.val[1], 0); - c0.val[3] = vmlaq_lane_s32(c0.val[3], ib0.val[0], ia0.val[1], 1); - // 4x4 block 0 - Accumulation 1 - c0.val[0] = vmlaq_lane_s32(c0.val[0], ib1.val[0], ia0.val[2], 0); - c0.val[1] = vmlaq_lane_s32(c0.val[1], ib1.val[0], ia0.val[2], 1); - c0.val[2] = vmlaq_lane_s32(c0.val[2], ib1.val[0], ia0.val[3], 0); - c0.val[3] = vmlaq_lane_s32(c0.val[3], ib1.val[0], ia0.val[3], 1); - // 4x4 block 0 - Accumulation 2 - c0.val[0] = vmlaq_lane_s32(c0.val[0], ib2.val[0], ia1.val[0], 0); - c0.val[1] = vmlaq_lane_s32(c0.val[1], ib2.val[0], ia1.val[0], 1); - c0.val[2] = vmlaq_lane_s32(c0.val[2], ib2.val[0], ia1.val[1], 0); - c0.val[3] = vmlaq_lane_s32(c0.val[3], ib2.val[0], ia1.val[1], 1); - // 4x4 block 0 - Accumulation 3 - c0.val[0] = vmlaq_lane_s32(c0.val[0], ib3.val[0], ia1.val[2], 0); - c0.val[1] = vmlaq_lane_s32(c0.val[1], ib3.val[0], ia1.val[2], 1); - c0.val[2] = vmlaq_lane_s32(c0.val[2], ib3.val[0], ia1.val[3], 0); - c0.val[3] = vmlaq_lane_s32(c0.val[3], ib3.val[0], ia1.val[3], 1); - - // 4x4 block 1 - Accumulation 0 - c1.val[0] = vmlaq_lane_s32(c1.val[0], ib0.val[1], ia0.val[0], 0); - c1.val[1] = vmlaq_lane_s32(c1.val[1], ib0.val[1], ia0.val[0], 1); - c1.val[2] = vmlaq_lane_s32(c1.val[2], ib0.val[1], ia0.val[1], 0); - c1.val[3] = vmlaq_lane_s32(c1.val[3], ib0.val[1], ia0.val[1], 1); - // 4x4 block 1 - Accumulation 1 - c1.val[0] = vmlaq_lane_s32(c1.val[0], ib1.val[1], ia0.val[2], 0); - c1.val[1] = vmlaq_lane_s32(c1.val[1], ib1.val[1], ia0.val[2], 1); - c1.val[2] = vmlaq_lane_s32(c1.val[2], ib1.val[1], ia0.val[3], 0); - c1.val[3] = vmlaq_lane_s32(c1.val[3], ib1.val[1], ia0.val[3], 1); - // 4x4 block 1 - Accumulation 2 - c1.val[0] = vmlaq_lane_s32(c1.val[0], ib2.val[1], ia1.val[0], 0); - c1.val[1] = vmlaq_lane_s32(c1.val[1], ib2.val[1], ia1.val[0], 1); - c1.val[2] = vmlaq_lane_s32(c1.val[2], ib2.val[1], ia1.val[1], 0); - c1.val[3] = vmlaq_lane_s32(c1.val[3], ib2.val[1], ia1.val[1], 1); - // 4x4 block 1 - Accumulation 3 - c1.val[0] = vmlaq_lane_s32(c1.val[0], ib3.val[1], ia1.val[2], 0); - c1.val[1] = vmlaq_lane_s32(c1.val[1], ib3.val[1], ia1.val[2], 1); - c1.val[2] = vmlaq_lane_s32(c1.val[2], ib3.val[1], ia1.val[3], 0); - c1.val[3] = vmlaq_lane_s32(c1.val[3], ib3.val[1], ia1.val[3], 1); - - // 4x4 block 2 - Accumulation 0 - c2.val[0] = vmlaq_lane_s32(c2.val[0], ib0.val[2], ia0.val[0], 0); - c2.val[1] = vmlaq_lane_s32(c2.val[1], ib0.val[2], ia0.val[0], 1); - c2.val[2] = vmlaq_lane_s32(c2.val[2], ib0.val[2], ia0.val[1], 0); - c2.val[3] = vmlaq_lane_s32(c2.val[3], ib0.val[2], ia0.val[1], 1); - // 4x4 block 2 - Accumulation 1 - c2.val[0] = vmlaq_lane_s32(c2.val[0], ib1.val[2], ia0.val[2], 0); - c2.val[1] = vmlaq_lane_s32(c2.val[1], ib1.val[2], ia0.val[2], 1); - c2.val[2] = vmlaq_lane_s32(c2.val[2], ib1.val[2], ia0.val[3], 0); - c2.val[3] = vmlaq_lane_s32(c2.val[3], ib1.val[2], ia0.val[3], 1); - // 4x4 block 2 - Accumulation 2 - c2.val[0] = vmlaq_lane_s32(c2.val[0], ib2.val[2], ia1.val[0], 0); - c2.val[1] = vmlaq_lane_s32(c2.val[1], ib2.val[2], ia1.val[0], 1); - c2.val[2] = vmlaq_lane_s32(c2.val[2], ib2.val[2], ia1.val[1], 0); - c2.val[3] = vmlaq_lane_s32(c2.val[3], ib2.val[2], ia1.val[1], 1); - // 4x4 block 2 - Accumulation 3 - c2.val[0] = vmlaq_lane_s32(c2.val[0], ib3.val[2], ia1.val[2], 0); - c2.val[1] = vmlaq_lane_s32(c2.val[1], ib3.val[2], ia1.val[2], 1); - c2.val[2] = vmlaq_lane_s32(c2.val[2], ib3.val[2], ia1.val[3], 0); - c2.val[3] = vmlaq_lane_s32(c2.val[3], ib3.val[2], ia1.val[3], 1); - - // 4x4 block 3 - Accumulation 0 - c3.val[0] = vmlaq_lane_s32(c3.val[0], ib0.val[3], ia0.val[0], 0); - c3.val[1] = vmlaq_lane_s32(c3.val[1], ib0.val[3], ia0.val[0], 1); - c3.val[2] = vmlaq_lane_s32(c3.val[2], ib0.val[3], ia0.val[1], 0); - c3.val[3] = vmlaq_lane_s32(c3.val[3], ib0.val[3], ia0.val[1], 1); - // 4x4 block 3 - Accumulation 1 - c3.val[0] = vmlaq_lane_s32(c3.val[0], ib1.val[3], ia0.val[2], 0); - c3.val[1] = vmlaq_lane_s32(c3.val[1], ib1.val[3], ia0.val[2], 1); - c3.val[2] = vmlaq_lane_s32(c3.val[2], ib1.val[3], ia0.val[3], 0); - c3.val[3] = vmlaq_lane_s32(c3.val[3], ib1.val[3], ia0.val[3], 1); - // 4x4 block 3 - Accumulation 2 - c3.val[0] = vmlaq_lane_s32(c3.val[0], ib2.val[3], ia1.val[0], 0); - c3.val[1] = vmlaq_lane_s32(c3.val[1], ib2.val[3], ia1.val[0], 1); - c3.val[2] = vmlaq_lane_s32(c3.val[2], ib2.val[3], ia1.val[1], 0); - c3.val[3] = vmlaq_lane_s32(c3.val[3], ib2.val[3], ia1.val[1], 1); - // 4x4 block 3 - Accumulation 3 - c3.val[0] = vmlaq_lane_s32(c3.val[0], ib3.val[3], ia1.val[2], 0); - c3.val[1] = vmlaq_lane_s32(c3.val[1], ib3.val[3], ia1.val[2], 1); - c3.val[2] = vmlaq_lane_s32(c3.val[2], ib3.val[3], ia1.val[3], 0); - c3.val[3] = vmlaq_lane_s32(c3.val[3], ib3.val[3], ia1.val[3], 1); - } - - // This for loop handles the left-over accumulations - for(; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16) - { - const uint8x8_t p00 = vld1_u8(mtx_a0); - const uint8x8_t q00l = vld1_u8(mtx_b0); - const uint8x8_t q00h = vld1_u8(mtx_b0 + 8); - - const int32x4_t ia0 = vaddw_s16(voffset_a, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(p00)))); - - const int32x2x2_t ia = - { - { - vget_low_s32(ia0), - vget_high_s32(ia0) - } - }; - - const int32x4x4_t ib0 = - { - { - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q00l)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q00l)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q00h)))), - vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q00h)))) + vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))), + vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))), + vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))), + vget_high_u16(vmovl_u8(vget_high_u8(b00_u8))) } }; // 4x4 block 0 - c0.val[0] = vmlaq_lane_s32(c0.val[0], ib0.val[0], ia.val[0], 0); - c0.val[1] = vmlaq_lane_s32(c0.val[1], ib0.val[0], ia.val[0], 1); - c0.val[2] = vmlaq_lane_s32(c0.val[2], ib0.val[0], ia.val[1], 0); - c0.val[3] = vmlaq_lane_s32(c0.val[3], ib0.val[0], ia.val[1], 1); + c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0); + c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0); + c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0); + c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0); // 4x4 block 1 - c1.val[0] = vmlaq_lane_s32(c1.val[0], ib0.val[1], ia.val[0], 0); - c1.val[1] = vmlaq_lane_s32(c1.val[1], ib0.val[1], ia.val[0], 1); - c1.val[2] = vmlaq_lane_s32(c1.val[2], ib0.val[1], ia.val[1], 0); - c1.val[3] = vmlaq_lane_s32(c1.val[3], ib0.val[1], ia.val[1], 1); + c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1); + c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1); + c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1); + c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1); // 4x4 block 2 - c2.val[0] = vmlaq_lane_s32(c2.val[0], ib0.val[2], ia.val[0], 0); - c2.val[1] = vmlaq_lane_s32(c2.val[1], ib0.val[2], ia.val[0], 1); - c2.val[2] = vmlaq_lane_s32(c2.val[2], ib0.val[2], ia.val[1], 0); - c2.val[3] = vmlaq_lane_s32(c2.val[3], ib0.val[2], ia.val[1], 1); + c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2); + c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2); + c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2); + c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2); // 4x4 block 3 - c3.val[0] = vmlaq_lane_s32(c3.val[0], ib0.val[3], ia.val[0], 0); - c3.val[1] = vmlaq_lane_s32(c3.val[1], ib0.val[3], ia.val[0], 1); - c3.val[2] = vmlaq_lane_s32(c3.val[2], ib0.val[3], ia.val[1], 0); - c3.val[3] = vmlaq_lane_s32(c3.val[3], ib0.val[3], ia.val[1], 1); + c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3); + c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3); + c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3); + c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3); } - c0.val[0] = vshlq_s32(vmulq_n_s32(c0.val[0], _output_mult_int), vshiftr); - c0.val[1] = vshlq_s32(vmulq_n_s32(c0.val[1], _output_mult_int), vshiftr); - c0.val[2] = vshlq_s32(vmulq_n_s32(c0.val[2], _output_mult_int), vshiftr); - c0.val[3] = vshlq_s32(vmulq_n_s32(c0.val[3], _output_mult_int), vshiftr); - - c1.val[0] = vshlq_s32(vmulq_n_s32(c1.val[0], _output_mult_int), vshiftr); - c1.val[1] = vshlq_s32(vmulq_n_s32(c1.val[1], _output_mult_int), vshiftr); - c1.val[2] = vshlq_s32(vmulq_n_s32(c1.val[2], _output_mult_int), vshiftr); - c1.val[3] = vshlq_s32(vmulq_n_s32(c1.val[3], _output_mult_int), vshiftr); - - c2.val[0] = vshlq_s32(vmulq_n_s32(c2.val[0], _output_mult_int), vshiftr); - c2.val[1] = vshlq_s32(vmulq_n_s32(c2.val[1], _output_mult_int), vshiftr); - c2.val[2] = vshlq_s32(vmulq_n_s32(c2.val[2], _output_mult_int), vshiftr); - c2.val[3] = vshlq_s32(vmulq_n_s32(c2.val[3], _output_mult_int), vshiftr); - - c3.val[0] = vshlq_s32(vmulq_n_s32(c3.val[0], _output_mult_int), vshiftr); - c3.val[1] = vshlq_s32(vmulq_n_s32(c3.val[1], _output_mult_int), vshiftr); - c3.val[2] = vshlq_s32(vmulq_n_s32(c3.val[2], _output_mult_int), vshiftr); - c3.val[3] = vshlq_s32(vmulq_n_s32(c3.val[3], _output_mult_int), vshiftr); - - const uint8x16x4_t r = - { - { - vcombine_u8(vqmovun_s16(vcombine_s16(vqmovn_s32(c0.val[0]), vqmovn_s32(c1.val[0]))), - vqmovun_s16(vcombine_s16(vqmovn_s32(c2.val[0]), vqmovn_s32(c3.val[0])))), - vcombine_u8(vqmovun_s16(vcombine_s16(vqmovn_s32(c0.val[1]), vqmovn_s32(c1.val[1]))), - vqmovun_s16(vcombine_s16(vqmovn_s32(c2.val[1]), vqmovn_s32(c3.val[1])))), - vcombine_u8(vqmovun_s16(vcombine_s16(vqmovn_s32(c0.val[2]), vqmovn_s32(c1.val[2]))), - vqmovun_s16(vcombine_s16(vqmovn_s32(c2.val[2]), vqmovn_s32(c3.val[2])))), - vcombine_u8(vqmovun_s16(vcombine_s16(vqmovn_s32(c0.val[3]), vqmovn_s32(c1.val[3]))), - vqmovun_s16(vcombine_s16(vqmovn_s32(c2.val[3]), vqmovn_s32(c3.val[3])))) - } - }; - - uint8_t *const mtx_out = out.ptr(); - vst1q_u8(mtx_out + 0 * out_stride, r.val[0]); - vst1q_u8(mtx_out + 1 * out_stride, r.val[1]); - vst1q_u8(mtx_out + 2 * out_stride, r.val[2]); - vst1q_u8(mtx_out + 3 * out_stride, r.val[3]); + auto mtx_out = reinterpret_cast(out.ptr()); + vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0])); + vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1])); + vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2])); + vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3])); + vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0])); + vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1])); + vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2])); + vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3])); + vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0])); + vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1])); + vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2])); + vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3])); + vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0])); + vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1])); + vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2])); + vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3])); }, ina, inb, out); } -- cgit v1.2.1