about summary refs log tree commit diff
path: root/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
diff options
context:
space:
mode:
author    David Mansell <David.Mansell@arm.com>  2018-05-17 18:51:26 +0100
committer Anthony Barbier <anthony.barbier@arm.com>  2018-11-02 16:52:54 +0000
commit    ce8f60510210efc0cf1c921fac75efc49bc70edc (patch)
tree      c9f6fb303593198d783639cce25e09ed160e2d0b /src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
parent    2d008a476f5f09f63574990a93e8bf606ae5629e (diff)
download  ComputeLibrary-ce8f60510210efc0cf1c921fac75efc49bc70edc.tar.gz
COMPMID-1177: Improved native GEMM.
Improve the native GEMM so it can cope with any value for M. Also change the selection code so that the native GEMM is selected if M is small and nmulti is large - Winograd needs GEMMs like this and they don't thread properly with the blocked GEMM. (also rename gemm_batched.hpp back to gemv_batched.hpp) Change-Id: I736c33373ada562cbc0c00540520a58103faa9d5 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/131739 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Tello <pablo.tello@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp  24
1 file changed, 11 insertions, 13 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 43df1aa779..c093761614 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -22,10 +22,10 @@
* SOFTWARE.
*/
#include "arm_gemm.hpp"
-#include "gemm_batched.hpp"
#include "gemm_common.hpp"
#include "gemm_interleaved.hpp"
#include "gemm_native.hpp"
+#include "gemv_batched.hpp"
#include "gemv_native_transposed.hpp"
#include "gemv_pretransposed.hpp"
@@ -41,12 +41,10 @@ template <>
UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
const unsigned int nbatches, const unsigned int nmulti,
const bool trA, const bool trB, const float alpha, const float beta,
- const int maxthreads, const bool pretransposed_hint)
-{
- /* Handle "batched GEMM" */
- if(M == 1 && nbatches > 1)
- {
- return UniqueGemmCommon<float, float>(new GemmBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ const int maxthreads, const bool pretransposed_hint) {
+ /* Handle "batched GEMV" */
+ if (M==1 && nbatches>1) {
+ return UniqueGemmCommon<float, float> (new GemvBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
}
#ifdef __aarch64__
/* Cases in priority order */
@@ -62,12 +60,12 @@ UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsig
return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, nmulti, beta));
}
- /* Native GEMM: requires M to be a multiple of 4, K at least 4, N a
- * multiple of 16, doesn't handle alpha and only makes sense for small
- * sizes. */
- if(N <= 128 && K <= 128 && ((M % 4) == 0) && (K >= 4) && ((N % 16) == 0) && alpha == 1.0f)
- {
- return UniqueGemmCommon<float, float>(new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
+ /* Native GEMM: requires K at least 4, N a multiple of 16, doesn't
+ * handle alpha or transpose. Use for small N/K, or if the blocked GEMM
+ * won't thread properly. */
+ if ((K >= 4) && ((N % 16) == 0) && alpha==1.0f && !trA && !trB &&
+ ((K <= 128 && N <= 128) || (nmulti > 1 && (M/maxthreads) < 8))) {
+ return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
}
/* Blocked GEMM, handles all cases. */