diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2018-04-13 13:44:10 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:49:37 +0000 |
commit | e7e96e09ff0d3e47797adf197aff2bc39671788c (patch) | |
tree | b52ecdd7627bdf51b8b8da9b9553cb900460222f /arm_compute/core/NEON | |
parent | 1ed1fc6d3b7d8494ce3bbc5f8b46bfde6fc586f9 (diff) | |
download | ComputeLibrary-e7e96e09ff0d3e47797adf197aff2bc39671788c.tar.gz |
COMPMID-1054 Update RSH's GEMM to add batch+multi support
Change-Id: Ib9d91b77f1d51976da4449fa1e6eeeffae307353
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127876
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'arm_compute/core/NEON')
-rw-r--r-- | arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp | 7 | ||||
-rw-r--r-- | arm_compute/core/NEON/kernels/assembly/gemm_common.hpp | 18 |
2 files changed, 21 insertions, 4 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp index d6c9931a21..0a541c6db9 100644 --- a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp +++ b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp @@ -34,6 +34,9 @@ template<typename Top, typename Tret> using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret> >; template<typename Top, typename Tret> -UniqueGemmCommon<Top, Tret> gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K, const bool trA, const bool trB, const Tret alpha, const Tret beta, const int maxthreads, const bool pretransposed_hint); - +UniqueGemmCommon<Top, Tret> gemm(const CPUInfo &ci, + const unsigned int M, const unsigned int N, const unsigned int K, + const unsigned int nbatches, const unsigned int nmulti, + const bool trA, const bool trB, const Tret alpha, const Tret beta, + const int maxthreads, const bool pretransposed_hint); } // namespace arm_gemm diff --git a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp index 7f47abcbb9..3919c339bf 100644 --- a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp +++ b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp @@ -39,23 +39,35 @@ class GemmCommon { protected: const To *_Aptr=nullptr; int _lda=0; + int _A_batch_stride=0; + int _A_multi_stride=0; const To *_Bptr=nullptr; int _ldb=0; + int _B_multi_stride=0; Tr *_Cptr=nullptr; int _ldc=0; + int _C_batch_stride=0; + int _C_multi_stride=0; public: /* Pass in the pointers to the arrays to be operated on and their * strides. This has a default implementation that just captures them * all in protected members. If B is pretransposed (see below) then the * settings for B here are ignored. */ - virtual void set_arrays(const To *A, const int lda, const To *B, const int ldb, Tr *C, const int ldc) { + virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride, + const To *B, const int ldb, /* batches share B */ const int B_multi_stride, + Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) { _Aptr = A; _lda = lda; + _A_batch_stride = A_batch_stride; + _A_multi_stride = A_multi_stride; _Bptr = B; _ldb = ldb; + _B_multi_stride = B_multi_stride; _Cptr = C; _ldc = ldc; + _C_batch_stride = C_batch_stride; + _C_multi_stride = C_multi_stride; } /* For threading, we divide the work into some number of units and work @@ -95,7 +107,9 @@ public: /* Total number of bytes of space needed for pretransposed arrays. */ virtual size_t get_B_pretransposed_array_size() const { return 0; } /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */ - virtual void pretranspose_B_array(void *buffer, const To *, const int) { }; + virtual void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) { }; + /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */ + virtual void set_pretransposed_B_data(void *buffer) { } // Destructor virtual ~GemmCommon() { } |