about summary refs log tree commit diff
path: root/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2018-04-13 13:44:10 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:37 +0000
commite7e96e09ff0d3e47797adf197aff2bc39671788c (patch)
treeb52ecdd7627bdf51b8b8da9b9553cb900460222f /arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
parent1ed1fc6d3b7d8494ce3bbc5f8b46bfde6fc586f9 (diff)
downloadComputeLibrary-e7e96e09ff0d3e47797adf197aff2bc39671788c.tar.gz
COMPMID-1054 Update RSH's GEMM to add batch+multi support
Change-Id: Ib9d91b77f1d51976da4449fa1e6eeeffae307353 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127876 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Tello <pablo.tello@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/gemm_common.hpp')
-rw-r--r--  arm_compute/core/NEON/kernels/assembly/gemm_common.hpp | 18
1 files changed, 16 insertions, 2 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
index 7f47abcbb9..3919c339bf 100644
--- a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
@@ -39,23 +39,35 @@ class GemmCommon {
protected:
const To *_Aptr=nullptr;
int _lda=0;
+ int _A_batch_stride=0;
+ int _A_multi_stride=0;
const To *_Bptr=nullptr;
int _ldb=0;
+ int _B_multi_stride=0;
Tr *_Cptr=nullptr;
int _ldc=0;
+ int _C_batch_stride=0;
+ int _C_multi_stride=0;
public:
/* Pass in the pointers to the arrays to be operated on and their
* strides. This has a default implementation that just captures them
* all in protected members. If B is pretransposed (see below) then the
* settings for B here are ignored. */
- virtual void set_arrays(const To *A, const int lda, const To *B, const int ldb, Tr *C, const int ldc) {
+ virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
+ const To *B, const int ldb, /* batches share B */ const int B_multi_stride,
+ Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) {
_Aptr = A;
_lda = lda;
+ _A_batch_stride = A_batch_stride;
+ _A_multi_stride = A_multi_stride;
_Bptr = B;
_ldb = ldb;
+ _B_multi_stride = B_multi_stride;
_Cptr = C;
_ldc = ldc;
+ _C_batch_stride = C_batch_stride;
+ _C_multi_stride = C_multi_stride;
}
/* For threading, we divide the work into some number of units and work
@@ -95,7 +107,9 @@ public:
/* Total number of bytes of space needed for pretransposed arrays. */
virtual size_t get_B_pretransposed_array_size() const { return 0; }
/* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
- virtual void pretranspose_B_array(void *buffer, const To *, const int) { };
+ virtual void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) { };
+ /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */
+ virtual void set_pretransposed_B_data(void *buffer) { }
// Destructor
virtual ~GemmCommon() { }