author    Georgios Pinitas <georgios.pinitas@arm.com>  2019-01-23 11:24:50 +0000
committer Michalis Spyrou <michalis.spyrou@arm.com>  2019-01-24 10:19:46 +0000
commit    1d480652b820317fc97ccbc3cb517e3b9e8be197 (patch)
tree      b3c845ec02cccf89430b95186ed3e3f2ae65e2bd /arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
parent    20b527a7029d02d0edda78fd92002cbc430dbe05 (diff)
download  ComputeLibrary-1d480652b820317fc97ccbc3cb517e3b9e8be197.tar.gz
COMPMID-1867: Add u8 and s8 hybrid assembly kernels.
Change-Id: Ifeb005f9d18d19feff11949474cce84d9e03749c
Reviewed-on: https://review.mlplatform.org/565
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/gemm_common.hpp')
-rw-r--r--  arm_compute/core/NEON/kernels/assembly/gemm_common.hpp | 122
1 file changed, 83 insertions(+), 39 deletions(-)
diff --git a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
index 7b4f0149e3..c72f210e56 100644
--- a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2019 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,42 +34,19 @@ namespace arm_gemm {
// working space (permute as they go along). This interface should support
// all of them.
-template<typename To, typename Tr>
-class GemmCommon {
-protected:
- const To *_Aptr=nullptr;
- int _lda=0;
- int _A_batch_stride=0;
- int _A_multi_stride=0;
- const To *_Bptr=nullptr;
- int _ldb=0;
- int _B_multi_stride=0;
- Tr *_Cptr=nullptr;
- int _ldc=0;
- int _C_batch_stride=0;
- int _C_multi_stride=0;
-
+// The real GemmCommon class is templated based on the operand and return
+// type. This is an interface class which is independent of those types.
+class IGemmCommon {
public:
/* Pass in the pointers to the arrays to be operated on and their
- * strides. This has a default implementation that just captures them
- * all in protected members. If B is pretransposed (see below) then the
- * settings for B here are ignored.
+ * strides. In the interface class these are passed as void pointers -
+ * the templated version overloads this function with a version which
+ * takes appropriately typed pointers. If B is pretransposed (see
+ * below) then the settings for B here are ignored.
*/
- virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
- const To *B, const int ldb, /* batches share B */ const int B_multi_stride,
- Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) {
- _Aptr = A;
- _lda = lda;
- _A_batch_stride = A_batch_stride;
- _A_multi_stride = A_multi_stride;
- _Bptr = B;
- _ldb = ldb;
- _B_multi_stride = B_multi_stride;
- _Cptr = C;
- _ldc = ldc;
- _C_batch_stride = C_batch_stride;
- _C_multi_stride = C_multi_stride;
- }
+ virtual void set_arrays(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride,
+ const void *B, const int ldb, /* batches share B */ const int B_multi_stride,
+ void *C, const int ldc, const int C_batch_stride, const int C_multi_stride) = 0;
/* For threading, we divide the work into some number of units and work
* out internally what unit corresponds to what work. This returns the
@@ -90,6 +67,9 @@ public:
*/
virtual void set_nthreads(int) { };
+ /* Whether this GEMM can be dynamically scheduled or not. */
+ virtual bool supports_dynamic_scheduling() const { return false; }
+
/* Actually do the work. Provide a threadid to index any per-thread
* buffers, and a start/end range to indicate which work to do. */
virtual void execute(unsigned int, unsigned int, int) = 0;
@@ -107,14 +87,78 @@ public:
virtual bool B_pretranspose_required() const { return false; }
/* Total number of bytes of space needed for pretransposed arrays. */
virtual size_t get_B_pretransposed_array_size() const { return 0; }
- /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
- /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
- virtual void pretranspose_B_array(void *, const To *, const int, const int) { };
+ /* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */
+ /* The "real" version of this depends on the templated operand type (see below). */
+ virtual void pretranspose_B_array(void *, const void *, const int, const int) = 0;
/* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */
virtual void set_pretransposed_B_data(void *) { }
// Destructor
- virtual ~GemmCommon() { }
+ virtual ~IGemmCommon() { }
+};
+
+/*
+ * "Real" GemmCommon class which is templated on the operand and return types.
+ *
+ * In addition to correctly typed versions of the functions that operate on
+ * operand and return data, this class provides a default implementation of
+ * 'set_arrays' to capture the provided arguments in protected class
+ * members, as essentially any implementation will need these.
+ */
+template<typename To, typename Tr>
+class GemmCommon : public IGemmCommon {
+protected:
+ const To *_Aptr=nullptr;
+ int _lda=0;
+ int _A_batch_stride=0;
+ int _A_multi_stride=0;
+ const To *_Bptr=nullptr;
+ int _ldb=0;
+ int _B_multi_stride=0;
+ Tr *_Cptr=nullptr;
+ int _ldc=0;
+ int _C_batch_stride=0;
+ int _C_multi_stride=0;
+
+public:
+ /* Pass in the pointers to the arrays to be operated on and their
+ * strides (templated version with appropriate types). */
+ virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
+ const To *B, const int ldb, /* batches share B */ const int B_multi_stride,
+ Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) {
+ _Aptr = A;
+ _lda = lda;
+ _A_batch_stride = A_batch_stride;
+ _A_multi_stride = A_multi_stride;
+ _Bptr = B;
+ _ldb = ldb;
+ _B_multi_stride = B_multi_stride;
+ _Cptr = C;
+ _ldc = ldc;
+ _C_batch_stride = C_batch_stride;
+ _C_multi_stride = C_multi_stride;
+ }
+
+ /* Implementation of the void * overload which casts its arguments to the appropriate type. */
+ void set_arrays(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride,
+ const void *B, const int ldb, /* batches share B */ const int B_multi_stride,
+ void *C, const int ldc, const int C_batch_stride, const int C_multi_stride) override {
+ set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride,
+ static_cast<const To *>(B), ldb, B_multi_stride,
+ static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride);
+ }
+
+ /*** "Pretransposed" interface ***/
+
+ /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
+ /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
+ virtual void pretranspose_B_array(void *, const To *, const int, const int) { };
+
+ /* Implementation of the void * overload which casts its arguments to the appropriate type. */
+ void pretranspose_B_array(void *out, const void *in, const int row_stride, const int multi_stride) override {
+ pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride);
+ }
+
};
-} // namespace arm_gemm
+} // namespace arm_gemm
\ No newline at end of file
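
As a usage note on the new split: the sketch below (not from the patch) drives a GEMM through the type-erased IGemmCommon interface, letting the void * overload of set_arrays() dispatch to the typed version in the concrete GemmCommon<To, Tr>. The run_gemm_untyped helper is hypothetical, and it assumes the pre-existing get_window_size() accessor that the threading comment in the header refers to.

// Illustrative only: configure and run a GEMM via the untyped interface.
// run_gemm_untyped is a hypothetical helper, not part of arm_gemm.
#include "gemm_common.hpp"

using namespace arm_gemm;

void run_gemm_untyped(IGemmCommon &gemm,
                      const void *A, const void *B, void *C,
                      int lda, int ldb, int ldc)
{
    // The void * overload dispatches to the typed set_arrays() of the
    // concrete GemmCommon<To, Tr> subclass. With a single batch and a
    // single "multi", the extra strides can simply be zero.
    gemm.set_arrays(A, lda, /* A_batch_stride */ 0, /* A_multi_stride */ 0,
                    B, ldb, /* B_multi_stride */ 0,
                    C, ldc, /* C_batch_stride */ 0, /* C_multi_stride */ 0);

    // Run the whole work window on one thread (thread id 0).
    // get_window_size() is assumed from the header's threading comment.
    gemm.execute(0, gemm.get_window_size(), 0);
}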
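The window/execute pairing also admits multi-threaded use. The following static even split across threads is one possible policy, not arm_gemm's actual scheduler; a caller could instead hand out finer-grained chunks when supports_dynamic_scheduling() returns true. run_parallel is a hypothetical helper.

#include <thread>
#include <vector>

// Illustrative only: statically divide the work window across nthreads.
void run_parallel(arm_gemm::IGemmCommon &gemm, int nthreads)
{
    gemm.set_nthreads(nthreads);

    const unsigned int window = gemm.get_window_size();
    std::vector<std::thread> workers;

    for (int t = 0; t < nthreads; t++) {
        // Contiguous slice [start, end) of the window for thread t.
        const unsigned int start = (window * t) / nthreads;
        const unsigned int end   = (window * (t + 1)) / nthreads;

        if (start < end) {
            workers.emplace_back([&gemm, start, end, t] {
                gemm.execute(start, end, t);
            });
        }
    }

    for (auto &w : workers) {
        w.join();
    }
}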
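Finally, the pretransposed-B methods form a small handshake: check B_pretranspose_required(), size a buffer with get_B_pretransposed_array_size(), fill it with pretranspose_B_array(), and keep the buffer alive across any execute() calls. A minimal sketch, with the hypothetical maybe_pretranspose helper assuming the caller owns the storage:

#include <cstdint>
#include <vector>

// Illustrative only: perform the pretranspose handshake if the kernel needs it.
void maybe_pretranspose(arm_gemm::IGemmCommon &gemm, const void *B,
                        int ldb, int B_multi_stride,
                        std::vector<uint8_t> &storage)
{
    if (!gemm.B_pretranspose_required()) {
        return; // Kernel consumes B directly; nothing to do.
    }

    // The buffer must remain allocated for the duration of any execute() calls.
    storage.resize(gemm.get_B_pretransposed_array_size());

    // Arguments: output buffer, source pointer, source row stride, source multi stride.
    gemm.pretranspose_B_array(storage.data(), B, ldb, B_multi_stride);
}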