From e7e96e09ff0d3e47797adf197aff2bc39671788c Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Fri, 13 Apr 2018 13:44:10 +0100 Subject: COMPMID-1054 Update RSH's GEMM to add batch+multi support Change-Id: Ib9d91b77f1d51976da4449fa1e6eeeffae307353 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127876 Tested-by: Jenkins Reviewed-by: Pablo Tello Reviewed-by: Anthony Barbier --- arm_compute/core/NEON/kernels/assembly/gemm_common.hpp | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) (limited to 'arm_compute/core/NEON/kernels/assembly/gemm_common.hpp') diff --git a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp index 7f47abcbb9..3919c339bf 100644 --- a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp +++ b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp @@ -39,23 +39,35 @@ class GemmCommon { protected: const To *_Aptr=nullptr; int _lda=0; + int _A_batch_stride=0; + int _A_multi_stride=0; const To *_Bptr=nullptr; int _ldb=0; + int _B_multi_stride=0; Tr *_Cptr=nullptr; int _ldc=0; + int _C_batch_stride=0; + int _C_multi_stride=0; public: /* Pass in the pointers to the arrays to be operated on and their * strides. This has a default implementation that just captures them * all in protected members. If B is pretransposed (see below) then the * settings for B here are ignored. */ - virtual void set_arrays(const To *A, const int lda, const To *B, const int ldb, Tr *C, const int ldc) { + virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride, + const To *B, const int ldb, /* batches share B */ const int B_multi_stride, + Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) { _Aptr = A; _lda = lda; + _A_batch_stride = A_batch_stride; + _A_multi_stride = A_multi_stride; _Bptr = B; _ldb = ldb; + _B_multi_stride = B_multi_stride; _Cptr = C; _ldc = ldc; + _C_batch_stride = C_batch_stride; + _C_multi_stride = C_multi_stride; } /* For threading, we divide the work into some number of units and work @@ -95,7 +107,9 @@ public: /* Total number of bytes of space needed for pretransposed arrays. */ virtual size_t get_B_pretransposed_array_size() const { return 0; } /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */ - virtual void pretranspose_B_array(void *buffer, const To *, const int) { }; + virtual void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) { }; + /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */ + virtual void set_pretransposed_B_data(void *buffer) { } // Destructor virtual ~GemmCommon() { } -- cgit v1.2.1