From cf9e29e3bd2fcd772c156c7866425335bfdbde6a Mon Sep 17 00:00:00 2001
From: Michele Di Giorgio
Date: Thu, 8 Oct 2020 11:54:42 +0100
Subject: COMPMID-3172: Remove padding from NEGEMMMatrixMultiplyKernel

The multiply_alpha template parameter has been removed, which reduces
the binary size by:
- ~4 kB for armv8.2a
- ~12 kB for armv8a

Change-Id: Ib499a18a4980a3ee7b201507b943f900adf20a73
Signed-off-by: Michele Di Giorgio
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4122
Tested-by: Arm Jenkins
Reviewed-by: Giorgio Arena
Reviewed-by: Georgios Pinitas
Comments-Addressed: Arm Jenkins
---
 .../NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp    | 851 ++++++++++++---------
 tests/validation/NEON/GEMM.cpp                     |  34 +
 2 files changed, 518 insertions(+), 367 deletions(-)

diff --git a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
index 5c5367c9c1..6f74e3fc06 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
@@ -23,11 +23,9 @@
  */
 #include "arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"
 
-#include "arm_compute/core/AccessWindowStatic.h"
 #include "arm_compute/core/CPP/Validate.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/IAccessWindow.h"
 #include "arm_compute/core/ITensor.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
@@ -39,25 +37,16 @@
 #include "src/core/NEON/NEFixedPoint.h"
 
 #include <arm_neon.h>
-#include <cstddef>
-#include <cstdint>
-#include <tuple>
-
-using namespace arm_compute;
 
 namespace arm_compute
 {
-class Coordinates;
-} // namespace arm_compute
-
 namespace
 {
-template <bool multiply_alpha>
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
 {
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
     const auto width_matrix_b  = static_cast<int32_t>(output->info()->dimension(0));
-    const auto in_b_stride     = static_cast<int32_t>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
+    const auto in_b_stride     = static_cast<int32_t>(input1->info()->strides_in_bytes()[1] / input1->info()->element_size());
     const auto num_elems_vec_a = static_cast<int32_t>(input0->info()->dimension(0));
 
     // The implementation computes 32 elements per iteration
@@ -67,7 +56,7 @@ void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT
     ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");
 
     Window win_out(window);
-    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+    win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
     win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
 
     Window win_a(window);
@@ -81,125 +70,174 @@ void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT
     {
         win_b = window;
     }
-    win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+    win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
     win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
 
     Iterator ina(input0, win_a);
     Iterator inb(input1, win_b);
     Iterator out(output, win_out);
 
+    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
+
     const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
-    ARM_COMPUTE_UNUSED(alpha_f16);
 
-    execute_window_loop(win_out, [&](const Coordinates & id)
+    execute_window_loop(win_out, [&](const Coordinates &)
     {
-        if(id.x() > 
width_matrix_b) + int x = window_start_x; + // Here we don't check for x lower equal than (window_end_x - window_step_x) because of + // window_end_x is computed above which may cause out-of-bound writes to the output. + for(; x < (window_end_x - window_step_x); x += window_step_x) { - return; - } + if(x > width_matrix_b) + { + return; + } - float16x8_t acc0 = vdupq_n_f16(0.f); - float16x8_t acc1 = vdupq_n_f16(0.f); - float16x8_t acc2 = vdupq_n_f16(0.f); - float16x8_t acc3 = vdupq_n_f16(0.f); + auto matrix_b = reinterpret_cast(inb.ptr()) + x; - auto vec_a = reinterpret_cast(ina.ptr()); - auto matrix_b = reinterpret_cast(inb.ptr()); + float16x8_t acc0 = vdupq_n_f16(0.f); + float16x8_t acc1 = vdupq_n_f16(0.f); + float16x8_t acc2 = vdupq_n_f16(0.f); + float16x8_t acc3 = vdupq_n_f16(0.f); - const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a; - for(; vec_a <= (vec_a_end_addr - 4);) - { - const float16x4_t a0l = vld1_f16(vec_a); - - float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); - float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); - float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); - float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); - float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); - float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); - float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); - float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); - - acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0)); - acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0)); - acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0)); - acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0)); - acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1)); - acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1)); - acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1)); - acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1)); - - matrix_b += 2 * in_b_stride; - - b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); - b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); - b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); - b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); - b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); - b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); - b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); - b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); - - acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2)); - acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2)); - acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2)); - acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2)); - acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3)); - acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3)); - acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3)); - acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3)); - - vec_a += 4; - matrix_b += 2 * in_b_stride; - } + auto vec_a = reinterpret_cast(ina.ptr()); + const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 4);) + { + const float16x4_t a0l = vld1_f16(vec_a); + + float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); + float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); + float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); + float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); + float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); + + acc0 = 
vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0)); + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1)); + + matrix_b += 2 * in_b_stride; + + b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); + b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); + b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); + b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); + b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2)); + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3)); + + vec_a += 4; + matrix_b += 2 * in_b_stride; + } - for(; vec_a < vec_a_end_addr;) - { - const float16_t a0 = *vec_a; - const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); - const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); - const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); - const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); - - acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0)); - acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0)); - acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0)); - acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0)); - - vec_a += 1; - matrix_b += in_b_stride; + for(; vec_a < vec_a_end_addr; ++vec_a) + { + const float16_t a0 = *vec_a; + const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); + const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0)); + acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0)); + acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0)); + acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0)); + + matrix_b += in_b_stride; + } + + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc0 = vmulq_f16(acc0, alpha_f16); + acc1 = vmulq_f16(acc1, alpha_f16); + acc2 = vmulq_f16(acc2, alpha_f16); + acc3 = vmulq_f16(acc3, alpha_f16); + } + + auto vec_out = reinterpret_cast(out.ptr()) + x; + + vst1q_f16(vec_out + 0, acc0); + vst1q_f16(vec_out + 8, acc1); + vst1q_f16(vec_out + 16, acc2); + vst1q_f16(vec_out + 24, acc3); } - // Multiply by the weight of matrix product (alpha) - if(multiply_alpha) + for(; x < window_end_x; ++x) { - acc0 = vmulq_f16(acc0, alpha_f16); - acc1 = vmulq_f16(acc1, alpha_f16); - acc2 = vmulq_f16(acc2, alpha_f16); - acc3 = vmulq_f16(acc3, alpha_f16); - } + if(x > width_matrix_b) + { + return; + } + + auto matrix_b = reinterpret_cast(inb.ptr()) + x; - const auto vec_out = reinterpret_cast(out.ptr()); + float16x4_t vacc = vdup_n_f16(0.f); + + auto vec_a = reinterpret_cast(ina.ptr()); + const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 4); vec_a += 
4) + { + const float16x4_t a0l = vld1_f16(vec_a); - vst1q_f16(vec_out + 0, acc0); - vst1q_f16(vec_out + 8, acc1); - vst1q_f16(vec_out + 16, acc2); - vst1q_f16(vec_out + 24, acc3); + const float16x4_t b_col = + { + *(matrix_b + 0 * in_b_stride), + *(matrix_b + 1 * in_b_stride), + *(matrix_b + 2 * in_b_stride), + *(matrix_b + 3 * in_b_stride), + }; + + vacc = vadd_f16(vacc, vmul_f16(a0l, b_col)); + + matrix_b += 4 * in_b_stride; + } + + float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3); + + for(; vec_a < vec_a_end_addr; ++vec_a) + { + const float16_t a0 = *vec_a; + const float16_t b00 = *matrix_b; + + acc += b00 * a0; + + matrix_b += in_b_stride; + } + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc *= static_cast(alpha); + } + + auto vec_out = reinterpret_cast(out.ptr()) + x; + + *(vec_out) = acc; + } }, ina, inb, out); -#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - ARM_COMPUTE_UNUSED(input0); - ARM_COMPUTE_UNUSED(input1); - ARM_COMPUTE_UNUSED(output); - ARM_COMPUTE_UNUSED(window); - ARM_COMPUTE_UNUSED(info); - ARM_COMPUTE_UNUSED(alpha); - ARM_COMPUTE_ERROR("Not implemented"); -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ } +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ -template void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha) { const auto width_matrix_b = static_cast(output->info()->dimension(0)); @@ -213,7 +251,7 @@ void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, IT const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x; Window win_out(window); - win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x)); + win_out.set(Window::DimX, Window::Dimension(0, 1, 1)); win_out.set(Window::DimY, Window::Dimension(0, 1, 1)); Window win_a(window); @@ -227,137 +265,215 @@ void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, IT { win_b = window; } - win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x)); + win_b.set(Window::DimX, Window::Dimension(0, 1, 1)); win_b.set(Window::DimY, Window::Dimension(0, 1, 1)); Iterator ina(input0, win_a); Iterator inb(input1, win_b); Iterator out(output, win_out); - execute_window_loop(win_out, [&](const Coordinates & id) + const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); + + const float32x4_t alpha_f32 = vdupq_n_f32(alpha); + + execute_window_loop(win_out, [&](const Coordinates &) { - if(id.x() > width_matrix_b) + int x = window_start_x; + // Here we don't check for x lower equal than (window_end_x - window_step_x) because of + // window_end_x is computed above which may cause out-of-bound writes to the output. 
+ for(; x < (window_end_x - window_step_x); x += window_step_x) { - return; - } + if(x > width_matrix_b) + { + return; + } - float32x4_t acc0 = vdupq_n_f32(0.f); - float32x4_t acc1 = vdupq_n_f32(0.f); - float32x4_t acc2 = vdupq_n_f32(0.f); - float32x4_t acc3 = vdupq_n_f32(0.f); + float32x4_t acc0 = vdupq_n_f32(0.f); + float32x4_t acc1 = vdupq_n_f32(0.f); + float32x4_t acc2 = vdupq_n_f32(0.f); + float32x4_t acc3 = vdupq_n_f32(0.f); - auto vec_a = reinterpret_cast(ina.ptr()); - auto matrix_b = reinterpret_cast(inb.ptr()); + auto vec_a = reinterpret_cast(ina.ptr()); + auto matrix_b = reinterpret_cast(inb.ptr()) + x; #if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b + in_b_stride))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b + in_b_stride))); #endif /* __arm__ */ - auto vec_a_end_addr = vec_a + num_elems_vec_a; - for(; vec_a <= (vec_a_end_addr - 4);) - { - float32x2_t a0l = vld1_f32(vec_a); + auto vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 4);) + { + float32x2_t a0l = vld1_f32(vec_a); - float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); - float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); - float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); - float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); + float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); + float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); + float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); + float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); - float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride); - float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride); - float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride); - float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride); + float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride); + float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride); + float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride); + float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride); #if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 1 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 2 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 3 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 4 * in_b_stride))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 1 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 2 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 3 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 4 * in_b_stride))); #endif /* __arm__ */ - acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0); - acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0); - acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0); - acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0); + acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0); + acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0); + acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0); + acc3 
= vmlaq_lane_f32(acc3, b03, a0l, 0); - acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1); - acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1); - acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1); - acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1); + acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1); + acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1); + acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1); + acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1); - vec_a += 2; - matrix_b += 2 * in_b_stride; + vec_a += 2; + matrix_b += 2 * in_b_stride; - a0l = vld1_f32(vec_a); + a0l = vld1_f32(vec_a); - b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); - b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); - b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); - b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); + b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); + b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); + b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); + b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); - b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride); - b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride); - b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride); - b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride); + b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride); + b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride); + b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride); + b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride); - acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0); - acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0); - acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0); - acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0); + acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0); + acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0); + acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0); + acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0); - acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1); - acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1); - acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1); - acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1); + acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1); + acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1); + acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1); + acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1); - vec_a += 2; - matrix_b += 2 * in_b_stride; - } + vec_a += 2; + matrix_b += 2 * in_b_stride; + } - for(; vec_a < vec_a_end_addr;) - { - const float a0 = *vec_a; + for(; vec_a < vec_a_end_addr; ++vec_a) + { + const float a0 = *vec_a; + + const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); + const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); + const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); + const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); - const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); - const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); - const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); - const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); + acc0 = vmlaq_n_f32(acc0, b00, a0); + acc1 = vmlaq_n_f32(acc1, b01, a0); + acc2 = vmlaq_n_f32(acc2, b02, a0); + acc3 = vmlaq_n_f32(acc3, b03, a0); - acc0 = vmlaq_n_f32(acc0, b00, a0); - acc1 = vmlaq_n_f32(acc1, b01, a0); - acc2 = vmlaq_n_f32(acc2, b02, a0); - acc3 = vmlaq_n_f32(acc3, b03, a0); + matrix_b += in_b_stride; + } - vec_a += 1; - matrix_b += in_b_stride; + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc0 = vmulq_f32(acc0, alpha_f32); + acc1 = vmulq_f32(acc1, alpha_f32); + acc2 = vmulq_f32(acc2, alpha_f32); + acc3 = vmulq_f32(acc3, alpha_f32); + } + + const auto vec_out = reinterpret_cast(out.ptr()) + x; + + vst1q_f32(vec_out + 0, acc0); + vst1q_f32(vec_out 
+ 4, acc1); + vst1q_f32(vec_out + 8, acc2); + vst1q_f32(vec_out + 12, acc3); } - // Multiply by the weight of matrix product (alpha) - if(multiply_alpha) + // Left-over loop + for(; x < window_end_x; ++x) { - const float32x4_t alpha_f32 = vdupq_n_f32(alpha); - acc0 = vmulq_f32(acc0, alpha_f32); - acc1 = vmulq_f32(acc1, alpha_f32); - acc2 = vmulq_f32(acc2, alpha_f32); - acc3 = vmulq_f32(acc3, alpha_f32); - } + if(x > width_matrix_b) + { + return; + } + + float32x4_t vacc = vdupq_n_f32(0.f); + + auto vec_a = reinterpret_cast(ina.ptr()); + auto matrix_b = reinterpret_cast(inb.ptr()) + x; + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b + in_b_stride))); +#endif /* __arm__ */ + + auto vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4) + { + const float32x4_t a0l = vld1q_f32(vec_a); + + const float32x4_t b_col = + { + *(matrix_b + 0 * in_b_stride), + *(matrix_b + 1 * in_b_stride), + *(matrix_b + 2 * in_b_stride), + *(matrix_b + 3 * in_b_stride), + }; + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 1 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 2 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 3 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 4 * in_b_stride))); +#endif /* __arm__ */ + + vacc = vmlaq_f32(vacc, b_col, a0l); + + matrix_b += 4 * in_b_stride; + } + + float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3); + + for(; vec_a < vec_a_end_addr; ++vec_a) + { + const float a0 = *vec_a; - const auto vec_out = reinterpret_cast(out.ptr()); + const float b00 = *matrix_b; - vst1q_f32(vec_out + 0, acc0); - vst1q_f32(vec_out + 4, acc1); - vst1q_f32(vec_out + 8, acc2); - vst1q_f32(vec_out + 12, acc3); + acc += b00 * a0; + + matrix_b += in_b_stride; + } + + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc *= alpha; + } + + const auto vec_out = reinterpret_cast(out.ptr()) + x; + + *vec_out = acc; + } }, ina, inb, out); } -template void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha) { + const int out_width = static_cast(output->info()->dimension(0)); + const int out_height = static_cast(output->info()->dimension(1)); const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()); const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type()); const size_t out_stride2 = out_stride1 * 2; @@ -385,10 +501,14 @@ void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, IT Iterator inb(input1, win_b); Iterator out(output, window); + const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); + + const float32x4_t alpha_f32 = vdupq_n_f32(alpha); + // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration // All the values needed for 
computing a single 4x4 block will be read from consecutive memory positions - execute_window_loop(window, [&](const Coordinates &) + execute_window_loop(window, [&](const Coordinates & id) { auto mtx_a0 = reinterpret_cast(ina.ptr()); auto mtx_b0 = reinterpret_cast(inb.ptr()); @@ -630,37 +750,103 @@ void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, IT // Multiply by the weight of matrix product (alpha) if(multiply_alpha) { - const float32x4_t alpha_f32 = vdupq_n_f32(alpha); - acc00 = vmulq_f32(acc00, alpha_f32); - acc10 = vmulq_f32(acc10, alpha_f32); - acc20 = vmulq_f32(acc20, alpha_f32); - acc30 = vmulq_f32(acc30, alpha_f32); - acc01 = vmulq_f32(acc01, alpha_f32); - acc11 = vmulq_f32(acc11, alpha_f32); - acc21 = vmulq_f32(acc21, alpha_f32); - acc31 = vmulq_f32(acc31, alpha_f32); + acc00 = vmulq_f32(acc00, alpha_f32); + acc10 = vmulq_f32(acc10, alpha_f32); + acc20 = vmulq_f32(acc20, alpha_f32); + acc30 = vmulq_f32(acc30, alpha_f32); + acc01 = vmulq_f32(acc01, alpha_f32); + acc11 = vmulq_f32(acc11, alpha_f32); + acc21 = vmulq_f32(acc21, alpha_f32); + acc31 = vmulq_f32(acc31, alpha_f32); } const auto mtx_out0 = reinterpret_cast(out.ptr()); const auto mtx_out1 = mtx_out0 + 4; - // Store the 4 blocks - vst1q_f32(mtx_out0, acc00); - vst1q_f32(mtx_out1, acc01); - vst1q_f32(mtx_out0 + out_stride1, acc10); - vst1q_f32(mtx_out1 + out_stride1, acc11); - vst1q_f32(mtx_out0 + out_stride2, acc20); - vst1q_f32(mtx_out1 + out_stride2, acc21); - vst1q_f32(mtx_out0 + out_stride3, acc30); - vst1q_f32(mtx_out1 + out_stride3, acc31); + if(id.x() < (out_width - 8)) + { + vst1q_f32(mtx_out0, acc00); + vst1q_f32(mtx_out1, acc01); + if(id.y() + 1 < out_height) + { + vst1q_f32(mtx_out0 + out_stride1, acc10); + vst1q_f32(mtx_out1 + out_stride1, acc11); + if(id.y() + 2 < out_height) + { + vst1q_f32(mtx_out0 + out_stride2, acc20); + vst1q_f32(mtx_out1 + out_stride2, acc21); + if(id.y() + 3 < out_height) + { + vst1q_f32(mtx_out0 + out_stride3, acc30); + vst1q_f32(mtx_out1 + out_stride3, acc31); + } + } + } + } + else if(id.x() < (out_width - 4)) + { + vst1q_f32(mtx_out0, acc00); + if(id.y() + 1 < out_height) + { + vst1q_f32(mtx_out0 + out_stride1, acc10); + if(id.y() + 2 < out_height) + { + vst1q_f32(mtx_out0 + out_stride2, acc20); + if(id.y() + 3 < out_height) + { + vst1q_f32(mtx_out0 + out_stride3, acc30); + } + } + } + // Left-over columns + const int columns_left = out_width - id.x() - 4; + for(auto x = 0; x < columns_left; ++x) + { + *(mtx_out1 + x) = acc01[x]; + if(id.y() + 1 < out_height) + { + *(mtx_out1 + x + out_stride1) = acc11[x]; + if(id.y() + 2 < out_height) + { + *(mtx_out1 + x + out_stride2) = acc21[x]; + if(id.y() + 3 < out_height) + { + *(mtx_out1 + x + out_stride3) = acc31[x]; + } + } + } + } + } + else + { + // Left-over columns + const int columns_left = out_width - id.x(); + for(int x = 0; x < columns_left; ++x) + { + *(mtx_out0 + x) = acc00[x]; + if(id.y() + 1 < out_height) + { + *(mtx_out0 + x + out_stride1) = acc10[x]; + if(id.y() + 2 < out_height) + { + *(mtx_out0 + x + out_stride2) = acc20[x]; + if(id.y() + 3 < out_height) + { + *(mtx_out0 + x + out_stride3) = acc30[x]; + } + } + } + } + } }, ina, inb, out); } -template +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha) { -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const int out_width = static_cast(output->info()->dimension(0)); + const int out_height = 
static_cast(output->info()->dimension(1)); const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()); const size_t out_stride = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type()); const int num_elems_matrix_b_x = input1->info()->dimension(0); @@ -685,9 +871,11 @@ void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT Iterator inb(input1, win_b); Iterator out(output, window); + const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); + const float16x8_t alpha_f16 = vdupq_n_f16(alpha); - execute_window_loop(window, [&](const Coordinates &) + execute_window_loop(window, [&](const Coordinates & id) { const auto *mtx_a0 = reinterpret_cast(ina.ptr()); const auto *mtx_b0 = reinterpret_cast(inb.ptr()); @@ -790,21 +978,47 @@ void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT c.val[3] = vmulq_f16(c.val[3], alpha_f16); } - vst1q_f16(mtx_out + 0 * out_stride, c.val[0]); - vst1q_f16(mtx_out + 1 * out_stride, c.val[1]); - vst1q_f16(mtx_out + 2 * out_stride, c.val[2]); - vst1q_f16(mtx_out + 3 * out_stride, c.val[3]); + if(id.x() < (out_width - 8)) + { + vst1q_f16(mtx_out, c.val[0]); + if(id.y() + 1 < out_height) + { + vst1q_f16(mtx_out + 1 * out_stride, c.val[1]); + if(id.y() + 2 < out_height) + { + vst1q_f16(mtx_out + 2 * out_stride, c.val[2]); + if(id.y() + 3 < out_height) + { + vst1q_f16(mtx_out + 3 * out_stride, c.val[3]); + } + } + } + } + else + { + // Left-over columns + const int columns_left = out_width - id.x(); + for(int x = 0; x < columns_left; ++x) + { + *(mtx_out + x) = c.val[0][x]; + if(id.y() + 1 < out_height) + { + *(mtx_out + x + 1 * out_stride) = c.val[1][x]; + if(id.y() + 2 < out_height) + { + *(mtx_out + x + 2 * out_stride) = c.val[2][x]; + if(id.y() + 3 < out_height) + { + *(mtx_out + x + 3 * out_stride) = c.val[3][x]; + } + } + } + } + } }, ina, inb, out); -#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - ARM_COMPUTE_UNUSED(input0); - ARM_COMPUTE_UNUSED(input1); - ARM_COMPUTE_UNUSED(output); - ARM_COMPUTE_UNUSED(window); - ARM_COMPUTE_UNUSED(alpha); - ARM_COMPUTE_ERROR("Not implemented"); -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ } +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info) { @@ -866,92 +1080,6 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i return Status{}; } - -inline std::pair validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output) -{ - bool window_changed{}; - Window win{}; - - unsigned int num_elems_processed_per_iteration_x = 0; - const unsigned int num_elems_processed_per_iteration_y = 4; - - // Check if the output tensor is a vector. 
If so,the kernel runs the vector-matrix multiplication - if((output->dimension(1) == 1)) - { - switch(input0->data_type()) - { - case DataType::F32: - { - num_elems_processed_per_iteration_x = 16; - break; - } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - { - num_elems_processed_per_iteration_x = 32; - break; - } -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - default: - { - ARM_COMPUTE_ERROR("Data type not supported"); - break; - } - } - - // Configure kernel window - win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x)); - - AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_x); - - window_changed = update_window_and_padding(win, - AccessWindowStatic(input0, 0, 0, input0->tensor_shape().x(), 1), - AccessWindowHorizontal(input1, 0, num_elems_processed_per_iteration_x), - output_access); - - Coordinates coord; - coord.set_num_dimensions(output->num_dimensions()); - output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape())); - } - else - { - switch(input0->data_type()) - { - case DataType::F32: - { - num_elems_processed_per_iteration_x = 8; - break; - } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - { - num_elems_processed_per_iteration_x = 8; - break; - } -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - default: - { - ARM_COMPUTE_ERROR("Data type not supported"); - break; - } - } - - // Configure kernel window - win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); - - AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y); - - window_changed = update_window_and_padding(win, - AccessWindowRectangle(input0, 0, 0, 4, 1, 1.f, 0.25f), - AccessWindowStatic(input1, 0, 0, input1->tensor_shape().x(), ceil_to_multiple(input1->tensor_shape().y(), 4)), - output_access); - - output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape())); - } - - Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; - return std::make_pair(err, win); -} } // namespace NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel() @@ -979,16 +1107,33 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor _alpha = alpha; // Configure kernel window - auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info()); - ARM_COMPUTE_ERROR_THROW_ON(win_config.first); - INEKernel::configure(win_config.second); + Window win{}; + + // Check if the output tensor is a vector. If so,the kernel runs the vector-matrix multiplication + if((output->info()->dimension(1) == 1)) + { + const unsigned int num_elems_processed_per_iteration_x = (input0->info()->data_type() == DataType::F32) ? 
16 : 32;
+
+        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
+    }
+    else
+    {
+        constexpr unsigned int num_elems_processed_per_iteration_x = 8;
+        constexpr unsigned int num_elems_processed_per_iteration_y = 4;
+
+        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+    }
+
+    Coordinates coord;
+    coord.set_num_dimensions(output->info()->num_dimensions());
+    output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
+
+    INEKernel::configure(win);
 }
 
 Status NEGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
 {
     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, alpha, is_interleaved, reshape_info));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), output->clone().get()).first);
 
     return Status{};
 }
@@ -998,57 +1143,29 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &inf
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
 
-    const bool multiply_alpha = !(helpers::float_ops::is_one(_alpha));
-
     // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
-    if((_output->info()->dimension(1) == 1))
+    const bool is_output_vector = (_output->info()->dimension(1) == 1);
+    switch(_input0->info()->data_type())
     {
-        switch(_input0->info()->data_type())
+        case DataType::F32:
         {
-            case DataType::F32:
-            {
-                multiply_alpha ? vector_matrix_multiply_f32<true>(_input0, _input1, _output, window, info, _alpha) :
-                                 vector_matrix_multiply_f32<false>(_input0, _input1, _output, window, info, _alpha);
-                break;
-            }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-            case DataType::F16:
-            {
-                multiply_alpha ? vector_matrix_multiply_f16<true>(_input0, _input1, _output, window, info, _alpha) :
-                                 vector_matrix_multiply_f16<false>(_input0, _input1, _output, window, info, _alpha);
-                break;
-            }
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-            default:
-            {
-                ARM_COMPUTE_ERROR("Data type not supported");
-                break;
-            }
+            is_output_vector ? vector_matrix_multiply_f32(_input0, _input1, _output, window, info, _alpha) :
+                               matrix_matrix_multiply_f32(_input0, _input1, _output, window, _alpha);
+            break;
         }
-    }
-    else
-    {
-        switch(_input0->info()->data_type())
-        {
-            case DataType::F32:
-            {
-                multiply_alpha ? matrix_matrix_multiply_f32<true>(_input0, _input1, _output, window, _alpha) :
-                                 matrix_matrix_multiply_f32<false>(_input0, _input1, _output, window, _alpha);
-                break;
-            }
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-            case DataType::F16:
-            {
-                multiply_alpha ? matrix_matrix_multiply_f16<true>(_input0, _input1, _output, window, _alpha) :
-                                 matrix_matrix_multiply_f16<false>(_input0, _input1, _output, window, _alpha);
-                break;
-            }
+        case DataType::F16:
+        {
+            is_output_vector ? vector_matrix_multiply_f16(_input0, _input1, _output, window, info, _alpha) :
+                               matrix_matrix_multiply_f16(_input0, _input1, _output, window, _alpha);
+            break;
+        }
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-            default:
-            {
-                ARM_COMPUTE_ERROR("Data type not supported");
-                break;
-            }
+        default:
+        {
+            ARM_COMPUTE_ERROR("Data type not supported");
+            break;
         }
     }
 }
+} // namespace arm_compute
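
The kernel changes above all follow one pattern, which is what lets configure() skip padding entirely: each vector-matrix pass walks the output row in full vector-width steps and then finishes the remaining (width % step) elements in a scalar leftover loop, so no store ever lands past the real row width. A minimal stand-alone sketch of that tail-loop idea, with plain scalar C++ standing in for the NEON intrinsics (the function name, the 4-lane width and the alpha scaling are illustrative, not taken from the patch):

#include <iostream>
#include <vector>

// No-padding store pattern: a main loop that takes only full vector-sized
// steps (so stores cannot run past the buffer), then a scalar tail loop for
// the leftover elements. The inner lane loop stands in for vld1q/vmulq/vst1q.
constexpr int vec_len = 4;

void scale_row(const float *in, float *out, int width, float alpha)
{
    int x = 0;
    for(; x <= width - vec_len; x += vec_len)
    {
        for(int lane = 0; lane < vec_len; ++lane)
        {
            out[x + lane] = in[x + lane] * alpha;
        }
    }
    // Leftover loop: the final (width % vec_len) elements, one at a time.
    for(; x < width; ++x)
    {
        out[x] = in[x] * alpha;
    }
}

int main()
{
    std::vector<float> in(7, 2.0f), out(7, 0.0f); // 7 is not a multiple of 4
    scale_row(in.data(), out.data(), 7, 0.5f);
    for(float v : out)
    {
        std::cout << v << ' '; // prints: 1 1 1 1 1 1 1
    }
    std::cout << '\n';
    return 0;
}

The trade-off is a few scalar iterations per row instead of padded tensors; the binary-size reduction comes separately, from dropping the <bool multiply_alpha> template duplication in favour of a runtime multiply_alpha flag.
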
diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp
index dfac72f3a5..25e8f28dc3 100644
--- a/tests/validation/NEON/GEMM.cpp
+++ b/tests/validation/NEON/GEMM.cpp
@@ -87,6 +87,20 @@ bool validate_zero_padding(unsigned int dim0_value, unsigned int dim1_value)
     return in.info()->padding().empty();
 }
 
+/* Zero padding test for GEMM kernels */
+bool validate_gemm_zero_padding(const TensorShape shape0, const TensorShape shape1)
+{
+    // Create tensors
+    Tensor in0 = create_tensor<Tensor>(shape0, DataType::F32);
+    Tensor in1 = create_tensor<Tensor>(shape1, DataType::F32);
+    Tensor dst;
+
+    // Validate zero-padding
+    NEGEMMMatrixMultiplyKernel gemm;
+    gemm.configure(&in0, &in1, &dst, 1.0, false);
+
+    return in0.info()->padding().empty() && in1.info()->padding().empty() && dst.info()->padding().empty();
+}
 } // namespace
 
 TEST_SUITE(NEON)
@@ -182,6 +196,26 @@ template <typename T>
 using NEGEMMFixtureDisabledC = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T, true>;
 
 TEST_SUITE(Float)
+DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(framework::dataset::make("In0", { TensorShape(21U, 13U),
+                                                                                                       TensorShape(31U, 1U),
+                                                                                                       TensorShape(31U, 1U),
+                                                                                                       TensorShape(8U, 2U),
+                                                                                                       TensorShape(38U, 12U),
+                                                                                                       TensorShape(32U, 1U)
+                                                                                                     }),
+                                                                     framework::dataset::make("In1", { TensorShape(33U, 21U),
+                                                                                                       TensorShape(23U, 31U),
+                                                                                                       TensorShape(23U, 31U),
+                                                                                                       TensorShape(16U, 8U),
+                                                                                                       TensorShape(21U, 38U),
+                                                                                                       TensorShape(17U, 32U)
+                                                                                                     })),
+               shape0, shape1)
+{
+    bool status = validate_gemm_zero_padding(shape0, shape1);
+    ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS);
+}
+
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_SUITE(FP16)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
--
cgit v1.2.1
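
For the reshaped matrix-matrix paths the same goal is reached on the store side: the kernel still computes a full block of accumulators, and the nested id.x()/id.y() checks seen above write back only the rows and columns that actually exist in the output. A condensed sketch of that guarded block store, with a plain 4x4 array in place of the NEON accumulators (the real kernels use 4x8 blocks; all names here are illustrative, not part of the patch):

#include <algorithm>
#include <iostream>

// Guarded block store: the block is computed in full, but only the elements
// inside the real output extent (out_w x out_h) are written back, so the
// output tensor needs no right or bottom padding.
void store_block(float *out, int out_stride, int out_w, int out_h,
                 int x0, int y0, const float acc[4][4])
{
    const int cols = std::min(4, out_w - x0); // leftover columns at the right edge
    const int rows = std::min(4, out_h - y0); // leftover rows at the bottom edge
    for(int r = 0; r < rows; ++r)
    {
        for(int c = 0; c < cols; ++c)
        {
            out[(y0 + r) * out_stride + (x0 + c)] = acc[r][c];
        }
    }
}

int main()
{
    float out[5 * 5] = {};
    const float acc[4][4] = { { 1, 2, 3, 4 }, { 5, 6, 7, 8 }, { 9, 10, 11, 12 }, { 13, 14, 15, 16 } };
    store_block(out, 5, 5, 5, 4, 4, acc); // block origin at the last row and column
    std::cout << out[4 * 5 + 4] << '\n';  // only out(4,4) is written: prints 1
    return 0;
}

The ValidateZeroPadding cases in the test diff exercise exactly these edges: shapes such as 21x13, 38x12 and 31x1 leave both vector-width and 4-row leftovers, and the test passes only if configure() requests no padding on any of the three tensors.
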