aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp10
1 files changed, 5 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 1baa21fd1b..a5b41cac2f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -49,16 +49,16 @@ UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsig
return UniqueGemmCommon<float, float>(new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, trB, beta));
}
- /* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle beta */
- if(M == 1 && beta == 1.0f && !trA && !trB)
+ /* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle alpha */
+ if(M == 1 && alpha == 1.0f && !trA && !trB)
{
- return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, alpha));
+ return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, beta));
}
- /* Native GEMM: requires M to be a multiple of 4, K a multiple of 4, N a
+ /* Native GEMM: requires M to be a multiple of 4, K at least 4, N a
* multiple of 16, doesn't handle alpha and only makes sense for small
* sizes. */
- if(N <= 128 && K <= 128 && ((M % 4) == 0) && ((K % 4) == 0) && ((N % 16) == 0) && alpha == 1.0f)
+ if(N <= 128 && K <= 128 && ((M % 4) == 0) && (K >= 4) && ((N % 16) == 0) && alpha == 1.0f)
{
return UniqueGemmCommon<float, float>(new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, beta));
}