author     Michele Di Giorgio <michele.digiorgio@arm.com>   2020-10-08 11:54:42 +0100
committer  Georgios Pinitas <georgios.pinitas@arm.com>      2020-10-14 09:49:26 +0000
commit     cf9e29e3bd2fcd772c156c7866425335bfdbde6a (patch)
tree       d363f53dec291cf2e93eaf339b6fb1a01e8a9b06
parent     87350f47084d2b69daa11c3b1c7eb47e02260063 (diff)
download   ComputeLibrary-cf9e29e3bd2fcd772c156c7866425335bfdbde6a.tar.gz
COMPMID-3172: Remove padding from NEGEMMMatrixMultiplyKernel
The template parameter has been removed, which reduces the binary size by:
- ~4 kB for armv8.2a
- ~12 kB for armv8a

Change-Id: Ib499a18a4980a3ee7b201507b943f900adf20a73
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4122
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp  851
-rw-r--r--  tests/validation/NEON/GEMM.cpp                          34
2 files changed, 518 insertions(+), 367 deletions(-)
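The size saving comes from dropping the <bool multiply_alpha> template parameter: previously the compiler emitted two full copies of every kernel body, one per boolean value, while the patched code computes the flag once at run time (via helpers::float_ops::is_one(alpha)) and hoists it out of the hot loops. A minimal sketch of the idea, using a hypothetical kernel body rather than the library's actual one:

#include <cstddef>

// Before: one instantiation per boolean value, i.e. two copies of the whole
// kernel body in the binary.
template <bool multiply_alpha>
void scale_accumulate_t(float *acc, const float *in, std::size_t n, float alpha)
{
    for(std::size_t i = 0; i < n; ++i)
    {
        float v = acc[i] + in[i];
        if(multiply_alpha) // resolved at compile time, separately per instantiation
        {
            v *= alpha;
        }
        acc[i] = v;
    }
}

// After: a single copy of the body; the flag costs one well-predicted branch.
void scale_accumulate(float *acc, const float *in, std::size_t n, float alpha)
{
    const bool multiply_alpha = (alpha != 1.f); // the kernel uses helpers::float_ops::is_one
    for(std::size_t i = 0; i < n; ++i)
    {
        float v = acc[i] + in[i];
        if(multiply_alpha)
        {
            v *= alpha;
        }
        acc[i] = v;
    }
}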
diff --git a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
index 5c5367c9c1..6f74e3fc06 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
@@ -23,11 +23,9 @@
*/
#include "arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"
-#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
@@ -39,25 +37,16 @@
#include "src/core/NEON/NEFixedPoint.h"
#include <arm_neon.h>
-#include <cstddef>
-#include <cstdint>
-#include <tuple>
-
-using namespace arm_compute;
namespace arm_compute
{
-class Coordinates;
-} // namespace arm_compute
-
namespace
{
-template <bool multiply_alpha>
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
- const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
+ const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / input1->info()->element_size());
const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
// The implementation computes 32 elements per iteration
@@ -67,7 +56,7 @@ void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT
ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");
Window win_out(window);
- win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+ win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
Window win_a(window);
@@ -81,125 +70,174 @@ void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT
{
win_b = window;
}
- win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+ win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
Iterator ina(input0, win_a);
Iterator inb(input1, win_b);
Iterator out(output, win_out);
+ const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
+
const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
- ARM_COMPUTE_UNUSED(alpha_f16);
- execute_window_loop(win_out, [&](const Coordinates & id)
+ execute_window_loop(win_out, [&](const Coordinates &)
{
- if(id.x() > width_matrix_b)
+ int x = window_start_x;
+ // Here we don't loop up to x == (window_end_x - window_step_x) because
+ // window_end_x is rounded up above, so a final full step could write out of
+ // bounds to the output; the scalar left-over loop below handles the tail.
+ for(; x < (window_end_x - window_step_x); x += window_step_x)
{
- return;
- }
+ if(x > width_matrix_b)
+ {
+ return;
+ }
- float16x8_t acc0 = vdupq_n_f16(0.f);
- float16x8_t acc1 = vdupq_n_f16(0.f);
- float16x8_t acc2 = vdupq_n_f16(0.f);
- float16x8_t acc3 = vdupq_n_f16(0.f);
+ auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;
- auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
- auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr());
+ float16x8_t acc0 = vdupq_n_f16(0.f);
+ float16x8_t acc1 = vdupq_n_f16(0.f);
+ float16x8_t acc2 = vdupq_n_f16(0.f);
+ float16x8_t acc3 = vdupq_n_f16(0.f);
- const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
- for(; vec_a <= (vec_a_end_addr - 4);)
- {
- const float16x4_t a0l = vld1_f16(vec_a);
-
- float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
- float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
- float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
- float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
- float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
- float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
- float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
- float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
-
- acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
- acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
- acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
- acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
- acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
- acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
- acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
- acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));
-
- matrix_b += 2 * in_b_stride;
-
- b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
- b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
- b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
- b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
- b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
- b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
- b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
- b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
-
- acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
- acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
- acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
- acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
- acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
- acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
- acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
- acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));
-
- vec_a += 4;
- matrix_b += 2 * in_b_stride;
- }
+ auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
+ const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
+ for(; vec_a <= (vec_a_end_addr - 4);)
+ {
+ const float16x4_t a0l = vld1_f16(vec_a);
+
+ float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
+ float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+ float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+ float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+ float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
+ float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
+ float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
+ float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
+
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));
+
+ matrix_b += 2 * in_b_stride;
+
+ b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
+ b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+ b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+ b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+ b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
+ b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
+ b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
+ b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
+
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));
+
+ vec_a += 4;
+ matrix_b += 2 * in_b_stride;
+ }
- for(; vec_a < vec_a_end_addr;)
- {
- const float16_t a0 = *vec_a;
- const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
- const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
- const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
- const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
-
- acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
- acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
- acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
- acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));
-
- vec_a += 1;
- matrix_b += in_b_stride;
+ for(; vec_a < vec_a_end_addr; ++vec_a)
+ {
+ const float16_t a0 = *vec_a;
+ const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
+ const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+ const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+ const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+
+ acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
+ acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
+ acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
+ acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));
+
+ matrix_b += in_b_stride;
+ }
+
+ // Multiply by the weight of the matrix product (alpha)
+ if(multiply_alpha)
+ {
+ acc0 = vmulq_f16(acc0, alpha_f16);
+ acc1 = vmulq_f16(acc1, alpha_f16);
+ acc2 = vmulq_f16(acc2, alpha_f16);
+ acc3 = vmulq_f16(acc3, alpha_f16);
+ }
+
+ auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;
+
+ vst1q_f16(vec_out + 0, acc0);
+ vst1q_f16(vec_out + 8, acc1);
+ vst1q_f16(vec_out + 16, acc2);
+ vst1q_f16(vec_out + 24, acc3);
}
- // Multiply by the weight of matrix product (alpha)
- if(multiply_alpha)
+ for(; x < window_end_x; ++x)
{
- acc0 = vmulq_f16(acc0, alpha_f16);
- acc1 = vmulq_f16(acc1, alpha_f16);
- acc2 = vmulq_f16(acc2, alpha_f16);
- acc3 = vmulq_f16(acc3, alpha_f16);
- }
+ if(x > width_matrix_b)
+ {
+ return;
+ }
+
+ auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;
- const auto vec_out = reinterpret_cast<float16_t *>(out.ptr());
+ float16x4_t vacc = vdup_n_f16(0.f);
+
+ auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
+ const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
+ for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
+ {
+ const float16x4_t a0l = vld1_f16(vec_a);
- vst1q_f16(vec_out + 0, acc0);
- vst1q_f16(vec_out + 8, acc1);
- vst1q_f16(vec_out + 16, acc2);
- vst1q_f16(vec_out + 24, acc3);
+ const float16x4_t b_col =
+ {
+ *(matrix_b + 0 * in_b_stride),
+ *(matrix_b + 1 * in_b_stride),
+ *(matrix_b + 2 * in_b_stride),
+ *(matrix_b + 3 * in_b_stride),
+ };
+
+ vacc = vadd_f16(vacc, vmul_f16(a0l, b_col));
+
+ matrix_b += 4 * in_b_stride;
+ }
+
+ float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3);
+
+ for(; vec_a < vec_a_end_addr; ++vec_a)
+ {
+ const float16_t a0 = *vec_a;
+ const float16_t b00 = *matrix_b;
+
+ acc += b00 * a0;
+
+ matrix_b += in_b_stride;
+ }
+ // Multiply by the weight of the matrix product (alpha)
+ if(multiply_alpha)
+ {
+ acc *= static_cast<float16_t>(alpha);
+ }
+
+ auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;
+
+ *(vec_out) = acc;
+ }
},
ina, inb, out);
-#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- ARM_COMPUTE_UNUSED(input0);
- ARM_COMPUTE_UNUSED(input1);
- ARM_COMPUTE_UNUSED(output);
- ARM_COMPUTE_UNUSED(window);
- ARM_COMPUTE_UNUSED(info);
- ARM_COMPUTE_UNUSED(alpha);
- ARM_COMPUTE_ERROR("Not implemented");
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
}
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
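Removing the padding requirement rests on the loop structure visible above: the vectorized main loop runs only while a full window_step_x-wide store still fits inside the output, and a scalar left-over loop finishes the trailing columns one element at a time. A minimal sketch of the pattern, assuming a hypothetical step of 4 (the actual kernels step by 16 or 32 columns):

// Vector main loop plus scalar left-over loop: no store ever passes 'width',
// so the output tensor needs no right-hand padding.
void scale_row(float *out, const float *in, int width, float alpha)
{
    constexpr int step = 4; // hypothetical; stands in for the NEON vector width
    int x = 0;
    // Main loop: take a full step only while all 'step' stores stay in bounds.
    for(; x <= width - step; x += step)
    {
        for(int i = 0; i < step; ++i) // stands in for a vld1q/vmulq/vst1q sequence
        {
            out[x + i] = in[x + i] * alpha;
        }
    }
    // Left-over loop: at most step - 1 scalar iterations.
    for(; x < width; ++x)
    {
        out[x] = in[x] * alpha;
    }
}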
-template <bool multiply_alpha>
void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
@@ -213,7 +251,7 @@ void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, IT
const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
Window win_out(window);
- win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+ win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
Window win_a(window);
@@ -227,137 +265,215 @@ void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, IT
{
win_b = window;
}
- win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+ win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
Iterator ina(input0, win_a);
Iterator inb(input1, win_b);
Iterator out(output, win_out);
- execute_window_loop(win_out, [&](const Coordinates & id)
+ const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
+
+ const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
+
+ execute_window_loop(win_out, [&](const Coordinates &)
{
- if(id.x() > width_matrix_b)
+ int x = window_start_x;
+ // Here we don't loop up to x == (window_end_x - window_step_x) because
+ // window_end_x is rounded up above, so a final full step could write out of
+ // bounds to the output; the scalar left-over loop below handles the tail.
+ for(; x < (window_end_x - window_step_x); x += window_step_x)
{
- return;
- }
+ if(x > width_matrix_b)
+ {
+ return;
+ }
- float32x4_t acc0 = vdupq_n_f32(0.f);
- float32x4_t acc1 = vdupq_n_f32(0.f);
- float32x4_t acc2 = vdupq_n_f32(0.f);
- float32x4_t acc3 = vdupq_n_f32(0.f);
+ float32x4_t acc0 = vdupq_n_f32(0.f);
+ float32x4_t acc1 = vdupq_n_f32(0.f);
+ float32x4_t acc2 = vdupq_n_f32(0.f);
+ float32x4_t acc3 = vdupq_n_f32(0.f);
- auto vec_a = reinterpret_cast<const float *>(ina.ptr());
- auto matrix_b = reinterpret_cast<const float *>(inb.ptr());
+ auto vec_a = reinterpret_cast<const float *>(ina.ptr());
+ auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
#if __arm__
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */
- auto vec_a_end_addr = vec_a + num_elems_vec_a;
- for(; vec_a <= (vec_a_end_addr - 4);)
- {
- float32x2_t a0l = vld1_f32(vec_a);
+ auto vec_a_end_addr = vec_a + num_elems_vec_a;
+ for(; vec_a <= (vec_a_end_addr - 4);)
+ {
+ float32x2_t a0l = vld1_f32(vec_a);
- float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
- float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
- float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
- float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
+ float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
+ float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
+ float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
+ float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
- float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
- float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
- float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
- float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
+ float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
+ float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
+ float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
+ float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
#if __arm__
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
- asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
- asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
- asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
- asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
+ asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
+ asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
+ asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
+ asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */
- acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
- acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
- acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
- acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
+ acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
+ acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
+ acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
+ acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
- acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
- acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
- acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
- acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
+ acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
+ acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
+ acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
+ acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
- vec_a += 2;
- matrix_b += 2 * in_b_stride;
+ vec_a += 2;
+ matrix_b += 2 * in_b_stride;
- a0l = vld1_f32(vec_a);
+ a0l = vld1_f32(vec_a);
- b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
- b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
- b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
- b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
+ b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
+ b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
+ b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
+ b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
- b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
- b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
- b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
- b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
+ b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
+ b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
+ b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
+ b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
- acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
- acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
- acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
- acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
+ acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
+ acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
+ acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
+ acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
- acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
- acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
- acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
- acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
+ acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
+ acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
+ acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
+ acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
- vec_a += 2;
- matrix_b += 2 * in_b_stride;
- }
+ vec_a += 2;
+ matrix_b += 2 * in_b_stride;
+ }
- for(; vec_a < vec_a_end_addr;)
- {
- const float a0 = *vec_a;
+ for(; vec_a < vec_a_end_addr; ++vec_a)
+ {
+ const float a0 = *vec_a;
+
+ const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
+ const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
+ const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
+ const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
- const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
- const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
- const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
- const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
+ acc0 = vmlaq_n_f32(acc0, b00, a0);
+ acc1 = vmlaq_n_f32(acc1, b01, a0);
+ acc2 = vmlaq_n_f32(acc2, b02, a0);
+ acc3 = vmlaq_n_f32(acc3, b03, a0);
- acc0 = vmlaq_n_f32(acc0, b00, a0);
- acc1 = vmlaq_n_f32(acc1, b01, a0);
- acc2 = vmlaq_n_f32(acc2, b02, a0);
- acc3 = vmlaq_n_f32(acc3, b03, a0);
+ matrix_b += in_b_stride;
+ }
- vec_a += 1;
- matrix_b += in_b_stride;
+ // Multiply by the weight of the matrix product (alpha)
+ if(multiply_alpha)
+ {
+ acc0 = vmulq_f32(acc0, alpha_f32);
+ acc1 = vmulq_f32(acc1, alpha_f32);
+ acc2 = vmulq_f32(acc2, alpha_f32);
+ acc3 = vmulq_f32(acc3, alpha_f32);
+ }
+
+ const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;
+
+ vst1q_f32(vec_out + 0, acc0);
+ vst1q_f32(vec_out + 4, acc1);
+ vst1q_f32(vec_out + 8, acc2);
+ vst1q_f32(vec_out + 12, acc3);
}
- // Multiply by the weight of matrix product (alpha)
- if(multiply_alpha)
+ // Left-over loop
+ for(; x < window_end_x; ++x)
{
- const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
- acc0 = vmulq_f32(acc0, alpha_f32);
- acc1 = vmulq_f32(acc1, alpha_f32);
- acc2 = vmulq_f32(acc2, alpha_f32);
- acc3 = vmulq_f32(acc3, alpha_f32);
- }
+ if(x > width_matrix_b)
+ {
+ return;
+ }
+
+ float32x4_t vacc = vdupq_n_f32(0.f);
+
+ auto vec_a = reinterpret_cast<const float *>(ina.ptr());
+ auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
+
+#if __arm__
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
+#endif /* __arm__ */
+
+ auto vec_a_end_addr = vec_a + num_elems_vec_a;
+ for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
+ {
+ const float32x4_t a0l = vld1q_f32(vec_a);
+
+ const float32x4_t b_col =
+ {
+ *(matrix_b + 0 * in_b_stride),
+ *(matrix_b + 1 * in_b_stride),
+ *(matrix_b + 2 * in_b_stride),
+ *(matrix_b + 3 * in_b_stride),
+ };
+
+#if __arm__
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
+ asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
+ asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
+ asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
+ asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
+#endif /* __arm__ */
+
+ vacc = vmlaq_f32(vacc, b_col, a0l);
+
+ matrix_b += 4 * in_b_stride;
+ }
+
+ float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3);
+
+ for(; vec_a < vec_a_end_addr; ++vec_a)
+ {
+ const float a0 = *vec_a;
- const auto vec_out = reinterpret_cast<float *>(out.ptr());
+ const float b00 = *matrix_b;
- vst1q_f32(vec_out + 0, acc0);
- vst1q_f32(vec_out + 4, acc1);
- vst1q_f32(vec_out + 8, acc2);
- vst1q_f32(vec_out + 12, acc3);
+ acc += b00 * a0;
+
+ matrix_b += in_b_stride;
+ }
+
+ // Multiply by the weight of the matrix product (alpha)
+ if(multiply_alpha)
+ {
+ acc *= alpha;
+ }
+
+ const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;
+
+ *vec_out = acc;
+ }
},
ina, inb, out);
}
-template <bool multiply_alpha>
void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
+ const int out_width = static_cast<int>(output->info()->dimension(0));
+ const int out_height = static_cast<int>(output->info()->dimension(1));
const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
const size_t out_stride2 = out_stride1 * 2;
@@ -385,10 +501,14 @@ void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, IT
Iterator inb(input1, win_b);
Iterator out(output, window);
+ const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
+
+ const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
+
// The implementation assumes that matrix A and matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
// The reshaping of the matrices makes the implementation cache-friendly and avoids the data re-arrangements needed for computing 16x4 elements per iteration
// All the values needed for computing a single 4x4 block will be read from consecutive memory positions
- execute_window_loop(window, [&](const Coordinates &)
+ execute_window_loop(window, [&](const Coordinates & id)
{
auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
@@ -630,37 +750,103 @@ void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, IT
// Multiply by the weight of matrix product (alpha)
if(multiply_alpha)
{
- const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
- acc00 = vmulq_f32(acc00, alpha_f32);
- acc10 = vmulq_f32(acc10, alpha_f32);
- acc20 = vmulq_f32(acc20, alpha_f32);
- acc30 = vmulq_f32(acc30, alpha_f32);
- acc01 = vmulq_f32(acc01, alpha_f32);
- acc11 = vmulq_f32(acc11, alpha_f32);
- acc21 = vmulq_f32(acc21, alpha_f32);
- acc31 = vmulq_f32(acc31, alpha_f32);
+ acc00 = vmulq_f32(acc00, alpha_f32);
+ acc10 = vmulq_f32(acc10, alpha_f32);
+ acc20 = vmulq_f32(acc20, alpha_f32);
+ acc30 = vmulq_f32(acc30, alpha_f32);
+ acc01 = vmulq_f32(acc01, alpha_f32);
+ acc11 = vmulq_f32(acc11, alpha_f32);
+ acc21 = vmulq_f32(acc21, alpha_f32);
+ acc31 = vmulq_f32(acc31, alpha_f32);
}
const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
const auto mtx_out1 = mtx_out0 + 4;
- // Store the 4 blocks
- vst1q_f32(mtx_out0, acc00);
- vst1q_f32(mtx_out1, acc01);
- vst1q_f32(mtx_out0 + out_stride1, acc10);
- vst1q_f32(mtx_out1 + out_stride1, acc11);
- vst1q_f32(mtx_out0 + out_stride2, acc20);
- vst1q_f32(mtx_out1 + out_stride2, acc21);
- vst1q_f32(mtx_out0 + out_stride3, acc30);
- vst1q_f32(mtx_out1 + out_stride3, acc31);
+ if(id.x() < (out_width - 8))
+ {
+ vst1q_f32(mtx_out0, acc00);
+ vst1q_f32(mtx_out1, acc01);
+ if(id.y() + 1 < out_height)
+ {
+ vst1q_f32(mtx_out0 + out_stride1, acc10);
+ vst1q_f32(mtx_out1 + out_stride1, acc11);
+ if(id.y() + 2 < out_height)
+ {
+ vst1q_f32(mtx_out0 + out_stride2, acc20);
+ vst1q_f32(mtx_out1 + out_stride2, acc21);
+ if(id.y() + 3 < out_height)
+ {
+ vst1q_f32(mtx_out0 + out_stride3, acc30);
+ vst1q_f32(mtx_out1 + out_stride3, acc31);
+ }
+ }
+ }
+ }
+ else if(id.x() < (out_width - 4))
+ {
+ vst1q_f32(mtx_out0, acc00);
+ if(id.y() + 1 < out_height)
+ {
+ vst1q_f32(mtx_out0 + out_stride1, acc10);
+ if(id.y() + 2 < out_height)
+ {
+ vst1q_f32(mtx_out0 + out_stride2, acc20);
+ if(id.y() + 3 < out_height)
+ {
+ vst1q_f32(mtx_out0 + out_stride3, acc30);
+ }
+ }
+ }
+ // Left-over columns
+ const int columns_left = out_width - id.x() - 4;
+ for(auto x = 0; x < columns_left; ++x)
+ {
+ *(mtx_out1 + x) = acc01[x];
+ if(id.y() + 1 < out_height)
+ {
+ *(mtx_out1 + x + out_stride1) = acc11[x];
+ if(id.y() + 2 < out_height)
+ {
+ *(mtx_out1 + x + out_stride2) = acc21[x];
+ if(id.y() + 3 < out_height)
+ {
+ *(mtx_out1 + x + out_stride3) = acc31[x];
+ }
+ }
+ }
+ }
+ }
+ else
+ {
+ // Left-over columns
+ const int columns_left = out_width - id.x();
+ for(int x = 0; x < columns_left; ++x)
+ {
+ *(mtx_out0 + x) = acc00[x];
+ if(id.y() + 1 < out_height)
+ {
+ *(mtx_out0 + x + out_stride1) = acc10[x];
+ if(id.y() + 2 < out_height)
+ {
+ *(mtx_out0 + x + out_stride2) = acc20[x];
+ if(id.y() + 3 < out_height)
+ {
+ *(mtx_out0 + x + out_stride3) = acc30[x];
+ }
+ }
+ }
+ }
+ }
},
ina, inb, out);
}
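The cascaded stores above apply the same idea to the matrix-matrix path: each 4x4 output block (and its adjacent 4-wide block) writes every extra row only if it exists, and falls back to element-wise stores at the right edge. The guard structure reduced to scalar form, as a hypothetical standalone helper:

#include <algorithm>
#include <cstddef>

// Edge-guarded 4x4 tile store: clamp the column and row counts so the tile
// never writes past out_width or out_height; 'out' points at the tile origin.
void store_tile_4x4(float *out, std::size_t out_stride, const float acc[4][4],
                    int x0, int y0, int out_width, int out_height)
{
    const int cols = std::min(4, out_width - x0);  // columns left at the right edge
    const int rows = std::min(4, out_height - y0); // rows left at the bottom edge
    for(int r = 0; r < rows; ++r)
    {
        for(int c = 0; c < cols; ++c)
        {
            out[static_cast<std::size_t>(r) * out_stride + c] = acc[r][c];
        }
    }
}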
-template <bool multiply_alpha>
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ const int out_width = static_cast<int>(output->info()->dimension(0));
+ const int out_height = static_cast<int>(output->info()->dimension(1));
const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
const size_t out_stride = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
const int num_elems_matrix_b_x = input1->info()->dimension(0);
@@ -685,9 +871,11 @@ void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT
Iterator inb(input1, win_b);
Iterator out(output, window);
+ const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
+
const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
- execute_window_loop(window, [&](const Coordinates &)
+ execute_window_loop(window, [&](const Coordinates & id)
{
const auto *mtx_a0 = reinterpret_cast<const float16_t *>(ina.ptr());
const auto *mtx_b0 = reinterpret_cast<const float16_t *>(inb.ptr());
@@ -790,21 +978,47 @@ void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT
c.val[3] = vmulq_f16(c.val[3], alpha_f16);
}
- vst1q_f16(mtx_out + 0 * out_stride, c.val[0]);
- vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
- vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
- vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
+ if(id.x() < (out_width - 8))
+ {
+ vst1q_f16(mtx_out, c.val[0]);
+ if(id.y() + 1 < out_height)
+ {
+ vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
+ if(id.y() + 2 < out_height)
+ {
+ vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
+ if(id.y() + 3 < out_height)
+ {
+ vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
+ }
+ }
+ }
+ }
+ else
+ {
+ // Left-over columns
+ const int columns_left = out_width - id.x();
+ for(int x = 0; x < columns_left; ++x)
+ {
+ *(mtx_out + x) = c.val[0][x];
+ if(id.y() + 1 < out_height)
+ {
+ *(mtx_out + x + 1 * out_stride) = c.val[1][x];
+ if(id.y() + 2 < out_height)
+ {
+ *(mtx_out + x + 2 * out_stride) = c.val[2][x];
+ if(id.y() + 3 < out_height)
+ {
+ *(mtx_out + x + 3 * out_stride) = c.val[3][x];
+ }
+ }
+ }
+ }
+ }
},
ina, inb, out);
-#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- ARM_COMPUTE_UNUSED(input0);
- ARM_COMPUTE_UNUSED(input1);
- ARM_COMPUTE_UNUSED(output);
- ARM_COMPUTE_UNUSED(window);
- ARM_COMPUTE_UNUSED(alpha);
- ARM_COMPUTE_ERROR("Not implemented");
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
}
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
{
@@ -866,92 +1080,6 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i
return Status{};
}
-
-inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output)
-{
- bool window_changed{};
- Window win{};
-
- unsigned int num_elems_processed_per_iteration_x = 0;
- const unsigned int num_elems_processed_per_iteration_y = 4;
-
- // Check if the output tensor is a vector. If so,the kernel runs the vector-matrix multiplication
- if((output->dimension(1) == 1))
- {
- switch(input0->data_type())
- {
- case DataType::F32:
- {
- num_elems_processed_per_iteration_x = 16;
- break;
- }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- {
- num_elems_processed_per_iteration_x = 32;
- break;
- }
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- default:
- {
- ARM_COMPUTE_ERROR("Data type not supported");
- break;
- }
- }
-
- // Configure kernel window
- win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x));
-
- AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_x);
-
- window_changed = update_window_and_padding(win,
- AccessWindowStatic(input0, 0, 0, input0->tensor_shape().x(), 1),
- AccessWindowHorizontal(input1, 0, num_elems_processed_per_iteration_x),
- output_access);
-
- Coordinates coord;
- coord.set_num_dimensions(output->num_dimensions());
- output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape()));
- }
- else
- {
- switch(input0->data_type())
- {
- case DataType::F32:
- {
- num_elems_processed_per_iteration_x = 8;
- break;
- }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- {
- num_elems_processed_per_iteration_x = 8;
- break;
- }
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- default:
- {
- ARM_COMPUTE_ERROR("Data type not supported");
- break;
- }
- }
-
- // Configure kernel window
- win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
-
- AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
-
- window_changed = update_window_and_padding(win,
- AccessWindowRectangle(input0, 0, 0, 4, 1, 1.f, 0.25f),
- AccessWindowStatic(input1, 0, 0, input1->tensor_shape().x(), ceil_to_multiple(input1->tensor_shape().y(), 4)),
- output_access);
-
- output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
- }
-
- Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
- return std::make_pair(err, win);
-}
} // namespace
NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
@@ -979,16 +1107,33 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor
_alpha = alpha;
// Configure kernel window
- auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info());
- ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- INEKernel::configure(win_config.second);
+ Window win{};
+
+ // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
+ if((output->info()->dimension(1) == 1))
+ {
+ const unsigned int num_elems_processed_per_iteration_x = (input0->info()->data_type() == DataType::F32) ? 16 : 32;
+
+ win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
+ }
+ else
+ {
+ constexpr unsigned int num_elems_processed_per_iteration_x = 8;
+ constexpr unsigned int num_elems_processed_per_iteration_y = 4;
+
+ win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ }
+
+ Coordinates coord;
+ coord.set_num_dimensions(output->info()->num_dimensions());
+ output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
+ INEKernel::configure(win);
}
Status NEGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved,
const GEMMReshapeInfo &reshape_info)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, alpha, is_interleaved, reshape_info));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), output->clone().get()).first);
return Status{};
}
@@ -998,57 +1143,29 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &inf
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- const bool multiply_alpha = !(helpers::float_ops::is_one(_alpha));
-
// Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
- if((_output->info()->dimension(1) == 1))
+ const bool is_output_vector = (_output->info()->dimension(1) == 1);
+ switch(_input0->info()->data_type())
{
- switch(_input0->info()->data_type())
+ case DataType::F32:
{
- case DataType::F32:
- {
- multiply_alpha ? vector_matrix_multiply_f32<true>(_input0, _input1, _output, window, info, _alpha) :
- vector_matrix_multiply_f32<false>(_input0, _input1, _output, window, info, _alpha);
- break;
- }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- {
- multiply_alpha ? vector_matrix_multiply_f16<true>(_input0, _input1, _output, window, info, _alpha) :
- vector_matrix_multiply_f16<false>(_input0, _input1, _output, window, info, _alpha);
- break;
- }
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- default:
- {
- ARM_COMPUTE_ERROR("Data type not supported");
- break;
- }
+ is_output_vector ? vector_matrix_multiply_f32(_input0, _input1, _output, window, info, _alpha) :
+ matrix_matrix_multiply_f32(_input0, _input1, _output, window, _alpha);
+ break;
}
- }
- else
- {
- switch(_input0->info()->data_type())
- {
- case DataType::F32:
- {
- multiply_alpha ? matrix_matrix_multiply_f32<true>(_input0, _input1, _output, window, _alpha) :
- matrix_matrix_multiply_f32<false>(_input0, _input1, _output, window, _alpha);
- break;
- }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- {
- multiply_alpha ? matrix_matrix_multiply_f16<true>(_input0, _input1, _output, window, _alpha) :
- matrix_matrix_multiply_f16<false>(_input0, _input1, _output, window, _alpha);
- break;
- }
+ case DataType::F16:
+ {
+ is_output_vector ? vector_matrix_multiply_f16(_input0, _input1, _output, window, info, _alpha) :
+ matrix_matrix_multiply_f16(_input0, _input1, _output, window, _alpha);
+ break;
+ }
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- default:
- {
- ARM_COMPUTE_ERROR("Data type not supported");
- break;
- }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Data type not supported");
+ break;
}
}
}
+} // namespace arm_compute
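The window arithmetic behind the vector paths deserves a worked example: window_end_x is rounded up to a multiple of the step, so it can overshoot the real width, and the strict main-loop bound plus the scalar tail keep every store in range. A self-contained sketch with hypothetical sizes (width 33, step 16); the tail guard is simplified compared to the kernel's:

#include <cassert>

// Mirrors the kernel's ceil_to_multiple-based bound: for width 33 and step 16
// the window ends at 48, the main loop takes two full steps (x = 0 and 16),
// and the scalar left-over loop covers the final column x = 32.
constexpr int ceil_to_multiple(int value, int divisor)
{
    return ((value + divisor - 1) / divisor) * divisor;
}

int main()
{
    const int width_matrix_b = 33;
    const int window_step_x  = 16;
    const int window_start_x = 0;
    const int window_end_x   = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
    assert(window_end_x == 48);

    int x = window_start_x;
    for(; x < (window_end_x - window_step_x); x += window_step_x)
    {
        // full 16-wide vector step: writes columns [x, x + 16)
    }
    assert(x == 32); // two full steps cover columns [0, 32)

    for(; x < width_matrix_b; ++x)
    {
        // scalar left-over step: writes column x only
    }
    assert(x == 33);
    return 0;
}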
diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp
index dfac72f3a5..25e8f28dc3 100644
--- a/tests/validation/NEON/GEMM.cpp
+++ b/tests/validation/NEON/GEMM.cpp
@@ -87,6 +87,20 @@ bool validate_zero_padding(unsigned int dim0_value, unsigned int dim1_value)
return in.info()->padding().empty();
}
+/* Zero padding test for GEMM kernels */
+bool validate_gemm_zero_padding(const TensorShape shape0, const TensorShape shape1)
+{
+ // Create tensors
+ Tensor in0 = create_tensor<Tensor>(shape0, DataType::F32);
+ Tensor in1 = create_tensor<Tensor>(shape1, DataType::F32);
+ Tensor dst;
+
+ // Validate zero-padding
+ NEGEMMMatrixMultiplyKernel gemm;
+ gemm.configure(&in0, &in1, &dst, 1.0, false);
+
+ return in0.info()->padding().empty() && in1.info()->padding().empty() && dst.info()->padding().empty();
+}
} // namespace
TEST_SUITE(NEON)
@@ -182,6 +196,26 @@ template <typename T>
using NEGEMMFixtureDisabledC = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T, true>;
TEST_SUITE(Float)
+DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(framework::dataset::make("In0", { TensorShape(21U, 13U),
+ TensorShape(31U, 1U),
+ TensorShape(31U, 1U),
+ TensorShape(8U, 2U),
+ TensorShape(38U, 12U),
+ TensorShape(32U, 1U)
+ }),
+ framework::dataset::make("In1", { TensorShape(33U, 21U),
+ TensorShape(23U, 31U),
+ TensorShape(23U, 31U),
+ TensorShape(16U, 8U),
+ TensorShape(21U, 38U),
+ TensorShape(17U, 32U)
+ })),
+ shape0, shape1)
+{
+ bool status = validate_gemm_zero_padding(shape0, shape1);
+ ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS);
+}
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(FP16)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),