path: root/src/core/NEON/kernels/arm_gemm
author     David Mansell <David.Mansell@arm.com>      2018-05-17 18:51:26 +0100
committer  Anthony Barbier <anthony.barbier@arm.com>  2018-11-02 16:52:54 +0000
commit     ce8f60510210efc0cf1c921fac75efc49bc70edc (patch)
tree       c9f6fb303593198d783639cce25e09ed160e2d0b /src/core/NEON/kernels/arm_gemm
parent     2d008a476f5f09f63574990a93e8bf606ae5629e (diff)
COMPMID-1177: Improved native GEMM.
Improve the native GEMM so it can cope with any value for M. Also change the selection code so that the native GEMM is selected if M is small and nmulti is large - Winograd needs GEMMs like this and they don't thread properly with the blocked GEMM. (Also rename gemm_batched.hpp back to gemv_batched.hpp.)

Change-Id: I736c33373ada562cbc0c00540520a58103faa9d5
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/131739
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
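For readers skimming the patch, the new kernel-selection rule in gemm_fp32.cpp (first hunk below) can be summarised by the following sketch. It paraphrases the condition from the diff and is not a verbatim copy of the library source; the function name is chosen here for illustration only.

    // Sketch of the updated heuristic for picking the native SGEMM kernel.
    // Parameters mirror the gemm<float, float>() signature shown in the diff.
    bool use_native_sgemm(unsigned int M, unsigned int N, unsigned int K, unsigned int nmulti,
                          int maxthreads, float alpha, bool trA, bool trB) {
        const bool kernel_supported = (K >= 4) && ((N % 16) == 0) && (alpha == 1.0f) && !trA && !trB;
        const bool small_problem    = (K <= 128) && (N <= 128);
        const bool blocked_gemm_threads_badly = (nmulti > 1) && ((M / maxthreads) < 8);
        return kernel_supported && (small_problem || blocked_gemm_threads_badly);
    }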
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp                                                            |  24
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemv_batched.hpp (renamed from src/core/NEON/kernels/arm_gemm/gemm_batched.hpp)  |   8
-rw-r--r--  src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp                                | 185
3 files changed, 120 insertions, 97 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 43df1aa779..c093761614 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -22,10 +22,10 @@
* SOFTWARE.
*/
#include "arm_gemm.hpp"
-#include "gemm_batched.hpp"
#include "gemm_common.hpp"
#include "gemm_interleaved.hpp"
#include "gemm_native.hpp"
+#include "gemv_batched.hpp"
#include "gemv_native_transposed.hpp"
#include "gemv_pretransposed.hpp"
@@ -41,12 +41,10 @@ template <>
UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
const unsigned int nbatches, const unsigned int nmulti,
const bool trA, const bool trB, const float alpha, const float beta,
- const int maxthreads, const bool pretransposed_hint)
-{
- /* Handle "batched GEMM" */
- if(M == 1 && nbatches > 1)
- {
- return UniqueGemmCommon<float, float>(new GemmBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ const int maxthreads, const bool pretransposed_hint) {
+ /* Handle "batched GEMV" */
+ if (M==1 && nbatches>1) {
+ return UniqueGemmCommon<float, float> (new GemvBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
}
#ifdef __aarch64__
/* Cases in priority order */
@@ -62,12 +60,12 @@ UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsig
return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, nmulti, beta));
}
- /* Native GEMM: requires M to be a multiple of 4, K at least 4, N a
- * multiple of 16, doesn't handle alpha and only makes sense for small
- * sizes. */
- if(N <= 128 && K <= 128 && ((M % 4) == 0) && (K >= 4) && ((N % 16) == 0) && alpha == 1.0f)
- {
- return UniqueGemmCommon<float, float>(new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
+ /* Native GEMM: requires K at least 4, N a multiple of 16, doesn't
+ * handle alpha or transpose. Use for small N/K, or if the blocked GEMM
+ * won't thread properly. */
+ if ((K >= 4) && ((N % 16) == 0) && alpha==1.0f && !trA && !trB &&
+ ((K <= 128 && N <= 128) || (nmulti > 1 && (M/maxthreads) < 8))) {
+ return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
}
/* Blocked GEMM, handles all cases. */
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
index 385358f615..bb09770efc 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_batched.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
@@ -27,14 +27,18 @@
namespace arm_gemm
{
+
+/* "Batched GEMV" (where M=1 and nbatches>1) can be executed much more
+ * efficiently as a GEMM (with M'=nbatches and nbatches'=1). This wrapper
+ * implements this. */
template <typename To, typename Tr>
-class GemmBatched : public GemmCommon<To, Tr>
+class GemvBatched : public GemmCommon<To, Tr>
{
private:
UniqueGemmCommon<To, Tr> _subgemm = nullptr;
public:
- GemmBatched(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
+ GemvBatched(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB,
const To alpha, const To beta, const int maxthreads, const bool pretransposed_hint)
{
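As the comment in this renamed header explains, a "batched GEMV" (M=1, nbatches>1) is a GEMM in disguise: the batch dimension becomes the row dimension. A minimal sketch of that remapping, written against the gemm<>() factory used in gemm_fp32.cpp; the wrapper's real constructor body is outside this hunk, so treat the helper name below as an assumption.

    // Hypothetical helper (not the library's code): build the equivalent GEMM
    // for a batched GEMV by swapping the roles of M and nbatches.
    template <typename To, typename Tr>
    UniqueGemmCommon<To, Tr> subgemm_for_batched_gemv(const CPUInfo &ci,
            const unsigned int N, const unsigned int K,
            const unsigned int nbatches, const unsigned int nmulti,
            const bool trA, const bool trB, const To alpha, const To beta,
            const int maxthreads, const bool pretransposed_hint) {
        // Original problem:   M == 1, nbatches > 1.
        // Equivalent problem: M' = nbatches, nbatches' = 1.
        return gemm<To, Tr>(ci, nbatches, N, K, 1, nmulti,
                            trA, trB, alpha, beta, maxthreads, pretransposed_hint);
    }

The real GemvBatched goes further than such a helper: since it implements the GemmCommon interface, it can also remap the array strides it receives (the batch strides of A and C become the row strides of the sub-GEMM) before delegating the work.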
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp
index 8d4a38c36d..2b846c7f10 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp
@@ -23,7 +23,9 @@
*/
#ifdef __aarch64__
+#include <algorithm>
#include <cstddef>
+#include <cstring>
#include <arm_neon.h>
@@ -35,22 +37,35 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
const int beta0 = (beta == 0.0f) ? 1 : 0;
const int oddones = (K % 4);
+ float dummy_buffer[16];
+
+ std::memset(dummy_buffer, 0, sizeof(dummy_buffer));
+
/* For now, very naive with no blocking */
- for(int y = 0; y < M; y += 4)
- {
- for(int x0 = 0; x0 < N; x0 += 16)
- {
- const float *a_ptr0 = A + (y * lda);
- const float *a_ptr1 = a_ptr0 + lda;
- const float *a_ptr2 = a_ptr1 + lda;
- const float *a_ptr3 = a_ptr2 + lda;
+ for (int y=0; y<M; y+=4) {
+ const int activerows = std::min(M-y, 4);
- const float *b_ptr = B + x0;
+ const float * const a_ptr0_base = A + (y * lda);
+ const float * const a_ptr1_base = (activerows > 1) ? (a_ptr0_base + lda) : dummy_buffer;
+ const float * const a_ptr2_base = (activerows > 2) ? (a_ptr1_base + lda) : dummy_buffer;
+ const float * const a_ptr3_base = (activerows > 3) ? (a_ptr2_base + lda) : dummy_buffer;
+
+ const unsigned long a_incr1 = (activerows > 1) ? 32 : 0;
+ const unsigned long a_incr2 = (activerows > 2) ? 32 : 0;
+ const unsigned long a_incr3 = (activerows > 3) ? 32 : 0;
- float *c_ptr0 = C + (y * ldc) + x0;
- float *c_ptr1 = c_ptr0 + ldc;
- float *c_ptr2 = c_ptr1 + ldc;
- float *c_ptr3 = c_ptr2 + ldc;
+ float *c_ptr0 = C + (y * ldc);
+ float *c_ptr1 = (activerows > 1) ? c_ptr0 + ldc : dummy_buffer;
+ float *c_ptr2 = (activerows > 2) ? c_ptr1 + ldc : dummy_buffer;
+ float *c_ptr3 = (activerows > 3) ? c_ptr2 + ldc : dummy_buffer;
+
+ for (int x0=0; x0<N; x0+=16) {
+ const float *a_ptr0 = a_ptr0_base;
+ const float *a_ptr1 = a_ptr1_base;
+ const float *a_ptr2 = a_ptr2_base;
+ const float *a_ptr3 = a_ptr3_base;
+
+ const float *b_ptr = B + x0;
int loops = ((K + 4) / 8) - 1;
int odds = oddones;
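The pointer setup added above is what lets this kernel handle any M: rows past the end of the matrix are pointed at a small zero-filled dummy_buffer, and their per-iteration increments (a_incr1..a_incr3) drop to 0 so those pointers never leave it; stores for the inactive rows likewise land in dummy_buffer and are simply discarded. A scalar sketch of the same trick, outside the NEON assembly (an assumption-level illustration, not the kernel itself):

    // Illustrative scalar version of the dummy-buffer tail handling.
    void process_in_blocks_of_4(const float *A, int lda, float *C, int ldc, int M) {
        float dummy[16] = {};                                   // zeroed scratch row
        for (int y = 0; y < M; y += 4) {
            const int active = (M - y < 4) ? (M - y) : 4;
            const float *a[4];
            float       *c[4];
            for (int r = 0; r < 4; r++) {
                a[r] = (r < active) ? (A + (y + r) * lda) : dummy;  // reads see zeros
                c[r] = (r < active) ? (C + (y + r) * ldc) : dummy;  // writes are thrown away
            }
            // The real kernel always computes a full 4-row block; results for
            // r >= active end up in dummy and are never read back.
            (void)a; (void)c;
        }
    }

Because the increments stay at 0 for inactive rows, all of their loads and stores remain inside dummy_buffer, which is why 16 floats of scratch are enough.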
@@ -228,34 +243,34 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
"ldr b3aq, [%[b_ptr], #48]\n"
// Unroll 2
- "fmla v16.4s, bb0.4s, a0.s[2]\n"
- "fmla v20.4s, bb0.4s, a1.s[2]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v24.4s, bb0.4s, a2.s[2]\n"
- "fmla v28.4s, bb0.4s, a3.s[2]\n"
- "ldr b0q, [%[b_ptr]]\n"
-
- "fmla v17.4s, bb1.4s, a0.s[2]\n"
- "add %[a_ptr0], %[a_ptr0], #32\n"
- "fmla v21.4s, bb1.4s, a1.s[2]\n"
- "add %[a_ptr1], %[a_ptr1], #32\n"
- "fmla v25.4s, bb1.4s, a2.s[2]\n"
- "add %[a_ptr2], %[a_ptr2], #32\n"
- "fmla v29.4s, bb1.4s, a3.s[2]\n"
- "ldr b1q, [%[b_ptr], #16]\n"
-
- "fmla v18.4s, bb2.4s, a0.s[2]\n"
- "add %[a_ptr3], %[a_ptr3], #32\n"
- "fmla v22.4s, bb2.4s, a1.s[2]\n"
- "fmla v26.4s, bb2.4s, a2.s[2]\n"
- "fmla v30.4s, bb2.4s, a3.s[2]\n"
- "ldr b2q, [%[b_ptr], #32]\n"
-
- "fmla v19.4s, bb3.4s, a0.s[2]\n"
- "fmla v23.4s, bb3.4s, a1.s[2]\n"
- "fmla v27.4s, bb3.4s, a2.s[2]\n"
- "fmla v31.4s, bb3.4s, a3.s[2]\n"
- "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v16.4s, bb0.4s, a0.s[2]\n"
+ "fmla v20.4s, bb0.4s, a1.s[2]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v24.4s, bb0.4s, a2.s[2]\n"
+ "fmla v28.4s, bb0.4s, a3.s[2]\n"
+ "ldr b0q, [%[b_ptr]]\n"
+
+ "fmla v17.4s, bb1.4s, a0.s[2]\n"
+ "add %[a_ptr0], %[a_ptr0], #32\n"
+ "fmla v21.4s, bb1.4s, a1.s[2]\n"
+ "add %[a_ptr1], %[a_ptr1], %[a_incr1]\n"
+ "fmla v25.4s, bb1.4s, a2.s[2]\n"
+ "add %[a_ptr2], %[a_ptr2], %[a_incr2]\n"
+ "fmla v29.4s, bb1.4s, a3.s[2]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
+
+ "fmla v18.4s, bb2.4s, a0.s[2]\n"
+ "add %[a_ptr3], %[a_ptr3], %[a_incr3]\n"
+ "fmla v22.4s, bb2.4s, a1.s[2]\n"
+ "fmla v26.4s, bb2.4s, a2.s[2]\n"
+ "fmla v30.4s, bb2.4s, a3.s[2]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
+
+ "fmla v19.4s, bb3.4s, a0.s[2]\n"
+ "fmla v23.4s, bb3.4s, a1.s[2]\n"
+ "fmla v27.4s, bb3.4s, a2.s[2]\n"
+ "fmla v31.4s, bb3.4s, a3.s[2]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
// Unroll 3
"fmla v16.4s, b0a.4s, a0.s[3]\n"
@@ -427,35 +442,35 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
"ldr b3q, [%[b_ptr], #48]\n"
// Unroll 1
- "fmla v16.4s, b0a.4s, a0.s[1]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v20.4s, b0a.4s, a1.s[1]\n"
- "ldr a3aq, [%[a_ptr3], #16]\n"
- "fmla v24.4s, b0a.4s, a2.s[1]\n"
- "fmla v28.4s, b0a.4s, a3.s[1]\n"
- "ldr b0aq, [%[b_ptr]]\n"
-
- "fmla v17.4s, b1a.4s, a0.s[1]\n"
- "add %[a_ptr0], %[a_ptr0], #32\n"
- "fmla v21.4s, b1a.4s, a1.s[1]\n"
- "add %[a_ptr1], %[a_ptr1], #32\n"
- "fmla v25.4s, b1a.4s, a2.s[1]\n"
- "add %[a_ptr2], %[a_ptr2], #32\n"
- "fmla v29.4s, b1a.4s, a3.s[1]\n"
- "ldr b1aq, [%[b_ptr], #16]\n"
-
- "fmla v18.4s, b2a.4s, a0.s[1]\n"
- "fmla v22.4s, b2a.4s, a1.s[1]\n"
- "add %[a_ptr3], %[a_ptr3], #32\n"
- "fmla v26.4s, b2a.4s, a2.s[1]\n"
- "fmla v30.4s, b2a.4s, a3.s[1]\n"
- "ldr b2aq, [%[b_ptr], #32]\n"
-
- "fmla v19.4s, b3a.4s, a0.s[1]\n"
- "fmla v23.4s, b3a.4s, a1.s[1]\n"
- "fmla v27.4s, b3a.4s, a2.s[1]\n"
- "fmla v31.4s, b3a.4s, a3.s[1]\n"
- "ldr b3aq, [%[b_ptr], #48]\n"
+ "fmla v16.4s, b0a.4s, a0.s[1]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v20.4s, b0a.4s, a1.s[1]\n"
+ "ldr a3aq, [%[a_ptr3], #16]\n"
+ "fmla v24.4s, b0a.4s, a2.s[1]\n"
+ "fmla v28.4s, b0a.4s, a3.s[1]\n"
+ "ldr b0aq, [%[b_ptr]]\n"
+
+ "fmla v17.4s, b1a.4s, a0.s[1]\n"
+ "add %[a_ptr0], %[a_ptr0], #32\n"
+ "fmla v21.4s, b1a.4s, a1.s[1]\n"
+ "add %[a_ptr1], %[a_ptr1], %[a_incr1]\n"
+ "fmla v25.4s, b1a.4s, a2.s[1]\n"
+ "add %[a_ptr2], %[a_ptr2], %[a_incr2]\n"
+ "fmla v29.4s, b1a.4s, a3.s[1]\n"
+ "ldr b1aq, [%[b_ptr], #16]\n"
+
+ "fmla v18.4s, b2a.4s, a0.s[1]\n"
+ "fmla v22.4s, b2a.4s, a1.s[1]\n"
+ "add %[a_ptr3], %[a_ptr3], %[a_incr3]\n"
+ "fmla v26.4s, b2a.4s, a2.s[1]\n"
+ "fmla v30.4s, b2a.4s, a3.s[1]\n"
+ "ldr b2aq, [%[b_ptr], #32]\n"
+
+ "fmla v19.4s, b3a.4s, a0.s[1]\n"
+ "fmla v23.4s, b3a.4s, a1.s[1]\n"
+ "fmla v27.4s, b3a.4s, a2.s[1]\n"
+ "fmla v31.4s, b3a.4s, a3.s[1]\n"
+ "ldr b3aq, [%[b_ptr], #48]\n"
// Unroll 2
"fmla v16.4s, bb0.4s, a0.s[2]\n"
@@ -848,18 +863,24 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
"str q27, [%[c_ptr2], #48]\n"
"3:\n"
- "str q28, [%[c_ptr3]]\n"
- "str q29, [%[c_ptr3], #16]\n"
- "str q30, [%[c_ptr3], #32]\n"
- "str q31, [%[c_ptr3], #48]\n"
-
- : [a_ptr0] "+r"(a_ptr0), [a_ptr1] "+r"(a_ptr1), [a_ptr2] "+r"(a_ptr2), [a_ptr3] "+r"(a_ptr3),
- [b_ptr] "+r"(b_ptr), [loops] "+r"(loops), [odds] "+r"(odds)
- : [ldb] "r"(ldbb), [oddk] "r"(oddk), [beta0] "r"(beta0), [betaptr] "r"(&beta),
- [c_ptr0] "r"(c_ptr0), [c_ptr1] "r"(c_ptr1), [c_ptr2] "r"(c_ptr2), [c_ptr3] "r"(c_ptr3)
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
- "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
- "cc", "memory");
+ "str q28, [%[c_ptr3]]\n"
+ "add %[c_ptr0], %[c_ptr0], #64\n"
+ "str q29, [%[c_ptr3], #16]\n"
+ "add %[c_ptr1], %[c_ptr1], %[a_incr1], LSL #1\n"
+ "str q30, [%[c_ptr3], #32]\n"
+ "add %[c_ptr2], %[c_ptr2], %[a_incr2], LSL #1\n"
+ "str q31, [%[c_ptr3], #48]\n"
+ "add %[c_ptr3], %[c_ptr3], %[a_incr3], LSL #1\n"
+
+ : [a_ptr0] "+r" (a_ptr0), [a_ptr1] "+r" (a_ptr1), [a_ptr2] "+r" (a_ptr2), [a_ptr3] "+r" (a_ptr3),
+ [b_ptr] "+r" (b_ptr), [loops] "+r" (loops), [odds] "+r" (odds),
+ [c_ptr0] "+r" (c_ptr0), [c_ptr1] "+r" (c_ptr1), [c_ptr2] "+r" (c_ptr2), [c_ptr3] "+r" (c_ptr3)
+ : [ldb] "r" (ldbb), [oddk] "r" (oddk), [beta0] "r" (beta0), [betaptr] "r" (&beta),
+ [a_incr1] "r" (a_incr1), [a_incr2] "r" (a_incr2), [a_incr3] "r" (a_incr3)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
+ "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
+ "cc", "memory"
+ );
}
}
}