diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/gemm_common.hpp')
-rw-r--r-- | arm_compute/core/NEON/kernels/assembly/gemm_common.hpp | 18 |
1 files changed, 16 insertions, 2 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp index 7f47abcbb9..3919c339bf 100644 --- a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp +++ b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp @@ -39,23 +39,35 @@ class GemmCommon { protected: const To *_Aptr=nullptr; int _lda=0; + int _A_batch_stride=0; + int _A_multi_stride=0; const To *_Bptr=nullptr; int _ldb=0; + int _B_multi_stride=0; Tr *_Cptr=nullptr; int _ldc=0; + int _C_batch_stride=0; + int _C_multi_stride=0; public: /* Pass in the pointers to the arrays to be operated on and their * strides. This has a default implementation that just captures them * all in protected members. If B is pretransposed (see below) then the * settings for B here are ignored. */ - virtual void set_arrays(const To *A, const int lda, const To *B, const int ldb, Tr *C, const int ldc) { + virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride, + const To *B, const int ldb, /* batches share B */ const int B_multi_stride, + Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) { _Aptr = A; _lda = lda; + _A_batch_stride = A_batch_stride; + _A_multi_stride = A_multi_stride; _Bptr = B; _ldb = ldb; + _B_multi_stride = B_multi_stride; _Cptr = C; _ldc = ldc; + _C_batch_stride = C_batch_stride; + _C_multi_stride = C_multi_stride; } /* For threading, we divide the work into some number of units and work @@ -95,7 +107,9 @@ public: /* Total number of bytes of space needed for pretransposed arrays. */ virtual size_t get_B_pretransposed_array_size() const { return 0; } /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */ - virtual void pretranspose_B_array(void *buffer, const To *, const int) { }; + virtual void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) { }; + /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */ + virtual void set_pretransposed_B_data(void *buffer) { } // Destructor virtual ~GemmCommon() { } |