diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 1baa21fd1b..a5b41cac2f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -49,16 +49,16 @@ UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsig return UniqueGemmCommon<float, float>(new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, trB, beta)); } - /* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle beta */ - if(M == 1 && beta == 1.0f && !trA && !trB) + /* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle alpha */ + if(M == 1 && alpha == 1.0f && !trA && !trB) { - return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, alpha)); + return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, beta)); } - /* Native GEMM: requires M to be a multiple of 4, K a multiple of 4, N a + /* Native GEMM: requires M to be a multiple of 4, K at least 4, N a * multiple of 16, doesn't handle alpha and only makes sense for small * sizes. */ - if(N <= 128 && K <= 128 && ((M % 4) == 0) && ((K % 4) == 0) && ((N % 16) == 0) && alpha == 1.0f) + if(N <= 128 && K <= 128 && ((M % 4) == 0) && (K >= 4) && ((N % 16) == 0) && alpha == 1.0f) { return UniqueGemmCommon<float, float>(new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, beta)); } |