aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/gemm_common.hpp')
-rw-r--r--arm_compute/core/NEON/kernels/assembly/gemm_common.hpp18
1 files changed, 16 insertions, 2 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
index 7f47abcbb9..3919c339bf 100644
--- a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
@@ -39,23 +39,35 @@ class GemmCommon {
protected:
const To *_Aptr=nullptr;
int _lda=0;
+ int _A_batch_stride=0;
+ int _A_multi_stride=0;
const To *_Bptr=nullptr;
int _ldb=0;
+ int _B_multi_stride=0;
Tr *_Cptr=nullptr;
int _ldc=0;
+ int _C_batch_stride=0;
+ int _C_multi_stride=0;
public:
/* Pass in the pointers to the arrays to be operated on and their
* strides. This has a default implementation that just captures them
* all in protected members. If B is pretransposed (see below) then the
* settings for B here are ignored. */
- virtual void set_arrays(const To *A, const int lda, const To *B, const int ldb, Tr *C, const int ldc) {
+ virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
+ const To *B, const int ldb, /* batches share B */ const int B_multi_stride,
+ Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) {
_Aptr = A;
_lda = lda;
+ _A_batch_stride = A_batch_stride;
+ _A_multi_stride = A_multi_stride;
_Bptr = B;
_ldb = ldb;
+ _B_multi_stride = B_multi_stride;
_Cptr = C;
_ldc = ldc;
+ _C_batch_stride = C_batch_stride;
+ _C_multi_stride = C_multi_stride;
}
/* For threading, we divide the work into some number of units and work
@@ -95,7 +107,9 @@ public:
/* Total number of bytes of space needed for pretransposed arrays. */
virtual size_t get_B_pretransposed_array_size() const { return 0; }
/* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
- virtual void pretranspose_B_array(void *buffer, const To *, const int) { };
+ virtual void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) { };
+ /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */
+ virtual void set_pretransposed_B_data(void *buffer) { }
// Destructor
virtual ~GemmCommon() { }