aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp27
1 files changed, 13 insertions, 14 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index c093761614..99f061bde8 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -29,15 +29,15 @@
#include "gemv_native_transposed.hpp"
#include "gemv_pretransposed.hpp"
-#include "kernels/a32_sgemm_8x6.hpp"
#include "kernels/a64_sgemm_12x8.hpp"
-#include "kernels/a64_sgemm_native_16x4.hpp"
-#include "kernels/a64_sgemv_pretransposed.hpp"
+#include "kernels/a32_sgemm_8x6.hpp"
#include "kernels/a64_sgemv_trans.hpp"
+#include "kernels/a64_sgemv_pretransposed.hpp"
+#include "kernels/a64_sgemm_native_16x4.hpp"
-namespace arm_gemm
-{
-template <>
+namespace arm_gemm {
+
+template<>
UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
const unsigned int nbatches, const unsigned int nmulti,
const bool trA, const bool trB, const float alpha, const float beta,
@@ -46,18 +46,17 @@ UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsig
if (M==1 && nbatches>1) {
return UniqueGemmCommon<float, float> (new GemvBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
}
+
#ifdef __aarch64__
/* Cases in priority order */
/* GemvPretransposed: requires M=1, alpha=1, and transposed hint set. nbatches must be 1 or we would have returned above so don't test. */
- if(M == 1 && alpha == 1.0f && pretransposed_hint)
- {
- return UniqueGemmCommon<float, float>(new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, nmulti, trB, beta));
+ if (M==1 && alpha==1.0f && pretransposed_hint) {
+ return UniqueGemmCommon<float, float> (new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, nmulti, trB, beta));
}
/* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle alpha */
- if(M == 1 && alpha == 1.0f && !trA && !trB)
- {
- return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, nmulti, beta));
+ if (M==1 && alpha==1.0f && !trA && !trB) {
+ return UniqueGemmCommon<float, float> (new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, nmulti, beta));
}
/* Native GEMM: requires K at least 4, N a multiple of 16, doesn't
@@ -69,9 +68,9 @@ UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsig
}
/* Blocked GEMM, handles all cases. */
- return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_12x8, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_12x8, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
#else
- return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_8x6, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_8x6, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
#endif
}