aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp24
1 files changed, 11 insertions, 13 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 43df1aa779..c093761614 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -22,10 +22,10 @@
* SOFTWARE.
*/
#include "arm_gemm.hpp"
-#include "gemm_batched.hpp"
#include "gemm_common.hpp"
#include "gemm_interleaved.hpp"
#include "gemm_native.hpp"
+#include "gemv_batched.hpp"
#include "gemv_native_transposed.hpp"
#include "gemv_pretransposed.hpp"
@@ -41,12 +41,10 @@ template <>
UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
const unsigned int nbatches, const unsigned int nmulti,
const bool trA, const bool trB, const float alpha, const float beta,
- const int maxthreads, const bool pretransposed_hint)
-{
- /* Handle "batched GEMM" */
- if(M == 1 && nbatches > 1)
- {
- return UniqueGemmCommon<float, float>(new GemmBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ const int maxthreads, const bool pretransposed_hint) {
+ /* Handle "batched GEMV" */
+ if (M==1 && nbatches>1) {
+ return UniqueGemmCommon<float, float> (new GemvBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
}
#ifdef __aarch64__
/* Cases in priority order */
@@ -62,12 +60,12 @@ UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsig
return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, nmulti, beta));
}
- /* Native GEMM: requires M to be a multiple of 4, K at least 4, N a
- * multiple of 16, doesn't handle alpha and only makes sense for small
- * sizes. */
- if(N <= 128 && K <= 128 && ((M % 4) == 0) && (K >= 4) && ((N % 16) == 0) && alpha == 1.0f)
- {
- return UniqueGemmCommon<float, float>(new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
+ /* Native GEMM: requires K at least 4, N a multiple of 16, doesn't
+ * handle alpha or transpose. Use for small N/K, or if the blocked GEMM
+ * won't thread properly. */
+ if ((K >= 4) && ((N % 16) == 0) && alpha==1.0f && !trA && !trB &&
+ ((K <= 128 && N <= 128) || (nmulti > 1 && (M/maxthreads) < 8))) {
+ return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
}
/* Blocked GEMM, handles all cases. */