From 4adaddbaa633a4025f29f2e0a63c7126d9d7c530 Mon Sep 17 00:00:00 2001
From: morgolock
Date: Tue, 29 Sep 2020 14:24:32 +0100
Subject: COMPMID-3170: Remove padding in NEGEMMLowpMatrixMultiplyKernel

Change-Id: Ie95442c6c6a145c1a45937b03cbd433bf08e36ab
Signed-off-by: morgolock
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4094
Reviewed-by: Michele Di Giorgio
Comments-Addressed: Arm Jenkins
Tested-by: Arm Jenkins
---
 .../kernels/NEGEMMLowpMatrixMultiplyKernel.cpp     | 337 ++++++++++++++-------
 1 file changed, 227 insertions(+), 110 deletions(-)

(limited to 'src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp')

diff --git a/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp
index c5d7f10e55..f3ba2901cb 100644
--- a/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp
@@ -23,7 +23,6 @@
  */
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
 
-#include "arm_compute/core/AccessWindowStatic.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/ITensor.h"
@@ -32,11 +31,7 @@
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/Window.h"
-
 #include <arm_neon.h>
-#include <cstddef>
-#include <cstdint>
-#include <tuple>
 
 using namespace arm_compute;
 
@@ -44,7 +39,7 @@ namespace arm_compute
 {
 namespace
 {
-void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, size_t stride_b, const Window &window)
+void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
 {
     execute_window_loop(window, [&](const Coordinates & id)
     {
@@ -253,15 +248,29 @@ void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &ou
         }
 
         auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
-        vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
-        vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
-        vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
-        vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
+        if(id.x() < (width_out - 16))
+        {
+            vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
+            vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
+            vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
+            vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
+        }
+        else
+        {
+            auto left_over = width_out - id.x();
+            for(auto k = 0; k < 4 && left_over; ++k)
+            {
+                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                {
+                    *(vec_out + k * 4 + j) = c0.val[k][j];
+                }
+            }
+        }
     },
     ina, inb, out);
 }
 
-void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, size_t stride_b, const Window &window)
+void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
 {
     execute_window_loop(window, [&](const Coordinates & id)
     {
@@ -469,17 +478,34 @@ void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &ou
         }
 
         auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
-        vst1q_s32(vec_out + 0, c0.val[0]);
-        vst1q_s32(vec_out + 4, c0.val[1]);
-        vst1q_s32(vec_out + 8, c0.val[2]);
-        vst1q_s32(vec_out + 12, c0.val[3]);
+        if(id.x() < (width_out - 16))
+        {
+            vst1q_s32(vec_out + 0, c0.val[0]);
+            vst1q_s32(vec_out + 4, c0.val[1]);
+            vst1q_s32(vec_out + 8, c0.val[2]);
+            vst1q_s32(vec_out + 12, c0.val[3]);
+        }
+        else
+        {
+            auto left_over = width_out - id.x();
+            for(auto k = 0; k < 4 && left_over; ++k)
+            {
+                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                {
+                    *(vec_out + k * 4 + j) = c0.val[k][j];
+                }
+            }
+        }
     },
     ina, inb, out);
 }
 
-void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, size_t out_stride, const Window &window)
+void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
 {
-    execute_window_loop(window, [&](const Coordinates &)
+    const auto   width_out  = static_cast<int>(out_info.dimension(0));
+    const auto   height_out = static_cast<int>(out_info.dimension(1));
+    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
+    execute_window_loop(window, [&](const Coordinates & id)
     {
         const uint8_t *mtx_a0 = ina.ptr();
         const uint8_t *mtx_b0 = inb.ptr();
@@ -574,32 +600,93 @@ void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int
         }
 
         auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
-        vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
-        vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
-        vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
-        vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
-        vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
-        vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
-        vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
-        vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
-        vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
-        vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
-        vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
-        vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
-        vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
-        vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
-        vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
-        vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
+
+        if(id.y() < height_out && id.x() < (width_out - 16))
+        {
+            vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
+            vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
+            vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
+            vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
+            if(id.y() + 1 < height_out)
+            {
+                vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
+                vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
+                vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
+                vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
+                if(id.y() + 2 < height_out)
+                {
+                    vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
+                    vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
+                    vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
+                    vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
+                    if(id.y() + 3 < height_out)
+                    {
+                        vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
+                        vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
+                        vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
+                        vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
+                    }
+                }
+            }
+        }
+        else
+        {
+            const auto left_over_value = width_out - id.x();
+            auto       left_over       = left_over_value;
+            for(auto k = 0; k < 4 && left_over; ++k)
+            {
+                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                {
+                    *(mtx_out + k * 4 + j) = c0.val[k][j];
+                }
+            }
+            if(id.y() + 1 < height_out)
+            {
+                left_over = left_over_value;
+                for(auto k = 0; k < 4 && left_over; ++k)
+                {
+                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                    {
+                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
+                    }
+                }
+                if(id.y() + 2 < height_out)
+                {
+                    left_over = left_over_value;
+                    for(auto k = 0; k < 4 && left_over; ++k)
+                    {
+                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                        {
+                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
+                        }
+                    }
+                    if(id.y() + 3 < height_out)
+                    {
+                        left_over = left_over_value;
+                        for(auto k = 0; k < 4 && left_over; ++k)
+                        {
+                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                            {
+                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
+                            }
+                        }
+                    }
+                }
+            }
+        }
     },
     ina, inb, out);
 }
 
-void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, size_t out_stride, const Window &window)
+void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
 {
+    const auto   width_out  = static_cast<int>(out_info.dimension(0));
+    const auto   height_out = static_cast<int>(out_info.dimension(1));
+    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
     // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
     // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
     // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
-    execute_window_loop(window, [&](const Coordinates &)
+    execute_window_loop(window, [&](const Coordinates & id)
     {
         auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
         auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());
@@ -692,32 +779,86 @@ void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int
             c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
             c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
         }
-        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
-        vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
-        vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
-        vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
-        vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
-        vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
-        vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
-        vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
-        vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
-        vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
-        vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
-        vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
-        vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
-        vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
-        vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
-        vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
-        vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
+        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
+        if(id.y() < height_out && id.x() < (width_out - 16))
+        {
+            vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
+            vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
+            vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
+            vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
+            if(id.y() + 1 < height_out)
+            {
+                vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
+                vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
+                vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
+                vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
+                if(id.y() + 2 < height_out)
+                {
+                    vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
+                    vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
+                    vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
+                    vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
+                    if(id.y() + 3 < height_out)
+                    {
+                        vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
+                        vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
+                        vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
+                        vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
+                    }
+                }
+            }
+        }
+        else if(id.y() < height_out)
+        {
+            const auto left_over_value = width_out - id.x();
+            auto       left_over       = left_over_value;
+            for(auto k = 0; k < 4 && left_over; ++k)
+            {
+                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                {
+                    *(mtx_out + k * 4 + j) = c0.val[k][j];
+                }
+            }
+            if(id.y() + 1 < height_out)
+            {
+                left_over = left_over_value;
+                for(auto k = 0; k < 4 && left_over; ++k)
+                {
+                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                    {
+                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
+                    }
+                }
+                if(id.y() + 2 < height_out)
+                {
+                    left_over = left_over_value;
+                    for(auto k = 0; k < 4 && left_over; ++k)
+                    {
+                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                        {
+                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
+                        }
+                    }
+                    if(id.y() + 3 < height_out)
+                    {
+                        left_over = left_over_value;
+                        for(auto k = 0; k < 4 && left_over; ++k)
+                        {
+                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                            {
+                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
+                            }
+                        }
+                    }
+                }
+            }
+        }
     },
     ina, inb, out);
 }
 
 } // namespace
-class Coordinates;
-} // namespace arm_compute
-
 namespace
 {
 Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
 {
@@ -748,50 +889,6 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
 
     return Status{};
 }
-
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output)
-{
-    constexpr unsigned int num_elems_processed_per_iteration_x = 16;
-    constexpr unsigned int num_elems_processed_per_iteration_y = 4;
-
-    Window win;
-    bool   window_changed = false;
-
-    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
-    if((output->dimension(1) == 1))
-    {
-        // Configure kernel window
-        win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x));
-
-        // We cannot read out-of-bound elements from matrix A as we use the left-over for loop
-        AccessWindowStatic     in0_access(input0, 0, 0, input0->tensor_shape().x(), 1);
-        AccessWindowHorizontal in1_access(input1, 0, num_elems_processed_per_iteration_x);
-        AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_x);
-
-        window_changed = update_window_and_padding(win, in0_access, in1_access, output_access);
-
-        Coordinates coord;
-        coord.set_num_dimensions(output->num_dimensions());
-        output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape()));
-    }
-    else
-    {
-        win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
-
-        unsigned int num_k_iterations = ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x) / 16;
-        // For each iteration of "k" we increment the input pointer by 4, and we load 8 elements at a time:
-        AccessWindowStatic    in0_access(input0, 0, 0, (num_k_iterations - 1) * 4 + 8, input0->dimension(1));
-        AccessWindowStatic    in1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
-        AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
-
-        window_changed = update_window_and_padding(win, in0_access, in1_access, output_access);
-
-        output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
-    }
-
-    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
-    return std::make_pair(err, win);
-}
 } // namespace
 
 NEGEMMLowpMatrixMultiplyKernel::NEGEMMLowpMatrixMultiplyKernel()
@@ -812,16 +909,33 @@ void NEGEMMLowpMatrixMultiplyKernel::configure(const ITensor *input0, const ITen
     _output         = output;
     _slide_matrix_b = in1_shape[2] != 1;
 
-    // Configure kernel window
-    auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info());
-    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
-    INEKernel::configure(win_config.second);
+    constexpr unsigned int num_elems_processed_per_iteration_x = 16;
+    constexpr unsigned int num_elems_processed_per_iteration_y = 4;
+
+    Window win;
+
+    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
+    if((output->info()->dimension(1) == 1))
+    {
+        // Configure kernel window
+        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
+
+        Coordinates coord;
+        coord.set_num_dimensions(output->info()->num_dimensions());
+        output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
+    }
+    else
+    {
+        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+        output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
+    }
+
+    INEKernel::configure(win);
 }
 
 Status NEGEMMLowpMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
 {
     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), output->clone().get()).first);
 
     return Status{};
 }
@@ -837,6 +951,7 @@ void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo
     {
         const auto width_matrix_a = static_cast<int>(_input0->info()->dimension(0));
         const auto width_matrix_b = static_cast<int>(_input1->info()->dimension(0));
+        const auto width_out      = static_cast<int>(_output->info()->dimension(0));
         const auto in_b_stride    = static_cast<int>(_input1->info()->strides_in_bytes()[1] / data_size_from_type(_input1->info()->data_type()));
 
         // The implementation computes 16 elements per iteration
@@ -872,13 +987,13 @@ void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo
             case DataType::S8:
             case DataType::QASYMM8_SIGNED:
             {
-                vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, in_b_stride, window);
+                vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                 break;
             }
             case DataType::U8:
             case DataType::QASYMM8:
             {
-                vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, in_b_stride, window);
+                vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                 break;
             }
             default:
@@ -891,7 +1006,7 @@ void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo
     else
    {
         const size_t in_b_stride = _input1->info()->strides_in_bytes()[1];
-        const size_t out_stride  = _output->info()->strides_in_bytes()[1] / _output->info()->element_size();
+        const int    width_b     = _input1->info()->dimension(0);
 
         // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix
         Window win_a(window);
@@ -914,19 +1029,18 @@ void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo
         Iterator inb(_input1, win_b);
         Iterator out(_output, window);
 
-        const int width_b = _input1->info()->dimension(0);
         switch(_input0->info()->data_type())
         {
             case DataType::S8:
             case DataType::QASYMM8_SIGNED:
             {
-                matrix_multiply_s8(ina, inb, out, width_b, out_stride, window);
+                matrix_multiply_s8(ina, inb, out, width_b, *_output->info(), window);
                 break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
-                matrix_multiply_u8(ina, inb, out, width_b, out_stride, window);
+                matrix_multiply_u8(ina, inb, out, width_b, *_output->info(), window);
                 break;
            }
            default:
@@ -937,3 +1051,6 @@ void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo
         }
     }
 }
+} // namespace arm_compute
+
+
--
cgit v1.2.1
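
What the new store path does, in isolation: interior 16-column blocks are still written with four full vst1q_s32 stores, while the block that reaches the end of a row is written lane by lane, so the output tensor no longer needs right-hand padding. The standalone sketch below reproduces that pattern on a flat int32_t buffer; the helper name store_block and the buffer sizes are illustrative only (this is not code from the library), but the branch condition and the left-over loops mirror the kernel above. It compiles with any AArch64 C++ compiler.

#include <arm_neon.h>
#include <cstdint>
#include <cstdio>

// Write one 16-wide accumulator block starting at column x of a row that has
// width_out valid columns. Interior blocks take the vector path; the last
// block of the row falls back to scalar lane stores, so no padding is needed.
static void store_block(const int32x4x4_t &acc, int32_t *row, int x, int width_out)
{
    int32_t *out = row + x;
    if(x < (width_out - 16))
    {
        // Full block: four 4-lane stores cover columns x .. x+15.
        vst1q_s32(out + 0, acc.val[0]);
        vst1q_s32(out + 4, acc.val[1]);
        vst1q_s32(out + 8, acc.val[2]);
        vst1q_s32(out + 12, acc.val[3]);
    }
    else
    {
        // Left-over block: write only the width_out - x remaining lanes.
        int left_over = width_out - x;
        for(int k = 0; k < 4 && left_over; ++k)
        {
            for(int j = 0; j < 4 && left_over; ++j, --left_over)
            {
                out[k * 4 + j] = acc.val[k][j];
            }
        }
    }
}

int main()
{
    int32_t     row[19] = {};
    int32x4x4_t acc;
    for(int k = 0; k < 4; ++k)
    {
        acc.val[k] = vdupq_n_s32(k + 1); // lanes 0-3 = 1, lanes 4-7 = 2, ...
    }
    store_block(acc, row, 0, 19);  // vector path: fills columns 0..15
    store_block(acc, row, 16, 19); // scalar tail: fills columns 16..18 only
    std::printf("%d %d %d %d\n", row[15], row[16], row[17], row[18]); // prints: 4 1 1 1
    return 0;
}

Note the boundary convention, which the kernel shares: a block starting exactly at width_out - 16 already takes the scalar path, and with left_over == 16 the loops write all sixteen lanes, giving the same result as the vector path, while a genuinely partial final block never stores past the end of the row.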