aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2018-04-13 13:44:10 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:37 +0000
commite7e96e09ff0d3e47797adf197aff2bc39671788c (patch)
treeb52ecdd7627bdf51b8b8da9b9553cb900460222f /src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
parent1ed1fc6d3b7d8494ce3bbc5f8b46bfde6fc586f9 (diff)
downloadComputeLibrary-e7e96e09ff0d3e47797adf197aff2bc39671788c.tar.gz
COMPMID-1054 Update RSH's GEMM to add batch+multi support
Change-Id: Ib9d91b77f1d51976da4449fa1e6eeeffae307353 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127876 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Tello <pablo.tello@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp19
1 files changed, 13 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index a5b41cac2f..43df1aa779 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -22,6 +22,7 @@
* SOFTWARE.
*/
#include "arm_gemm.hpp"
+#include "gemm_batched.hpp"
#include "gemm_common.hpp"
#include "gemm_interleaved.hpp"
#include "gemm_native.hpp"
@@ -38,21 +39,27 @@ namespace arm_gemm
{
template <>
UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
+ const unsigned int nbatches, const unsigned int nmulti,
const bool trA, const bool trB, const float alpha, const float beta,
const int maxthreads, const bool pretransposed_hint)
{
+ /* Handle "batched GEMM" */
+ if(M == 1 && nbatches > 1)
+ {
+ return UniqueGemmCommon<float, float>(new GemmBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ }
#ifdef __aarch64__
/* Cases in priority order */
- /* GemvPretransposed: requires M=1, alpha=1, and transposed hint set */
+ /* GemvPretransposed: requires M=1, alpha=1, and transposed hint set. nbatches must be 1 or we would have returned above so don't test. */
if(M == 1 && alpha == 1.0f && pretransposed_hint)
{
- return UniqueGemmCommon<float, float>(new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, trB, beta));
+ return UniqueGemmCommon<float, float>(new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, nmulti, trB, beta));
}
/* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle alpha */
if(M == 1 && alpha == 1.0f && !trA && !trB)
{
- return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, beta));
+ return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, nmulti, beta));
}
/* Native GEMM: requires M to be a multiple of 4, K at least 4, N a
@@ -60,13 +67,13 @@ UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsig
* sizes. */
if(N <= 128 && K <= 128 && ((M % 4) == 0) && (K >= 4) && ((N % 16) == 0) && alpha == 1.0f)
{
- return UniqueGemmCommon<float, float>(new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, beta));
+ return UniqueGemmCommon<float, float>(new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
}
/* Blocked GEMM, handles all cases. */
- return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_12x8, float, float>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_12x8, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
#else
- return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_8x6, float, float>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_8x6, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
#endif
}