Diffstat (limited to 'src/core/NEON/kernels/assembly/gemm_common.hpp')
-rw-r--r--  src/core/NEON/kernels/assembly/gemm_common.hpp  150
1 file changed, 77 insertions(+), 73 deletions(-)
diff --git a/src/core/NEON/kernels/assembly/gemm_common.hpp b/src/core/NEON/kernels/assembly/gemm_common.hpp
index a44b774b9d..3b4c025371 100644
--- a/src/core/NEON/kernels/assembly/gemm_common.hpp
+++ b/src/core/NEON/kernels/assembly/gemm_common.hpp
@@ -23,15 +23,12 @@
*/
#pragma once
-#include "arm_gemm_compute_iface.hpp"
+#include "ndrange.hpp"
#include <cstddef>
-#include <cassert>
-
-#define UNUSED(x) (void)(x)
-
-namespace arm_gemm {
+namespace arm_gemm
+{
// Abstract class for the GEMM/GEMV functions.
//
// GEMM implementations may be "native" (never require any input
@@ -41,7 +38,8 @@ namespace arm_gemm {
// The real GemmCommon class is templated based on the operand and return
// type. This is an interface class which is independent of those types.
-class IGemmCommon {
+class IGemmCommon
+{
public:
/* Pass in the pointers to the arrays to be operated on and their
* strides. This "generic" version uses void *s, the preferred version
@@ -50,9 +48,9 @@ public:
* the settings for B here are ignored.
*/
virtual void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride,
- const void *B, const int ldb, /* batches share B */ const int B_multi_stride,
- void *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
- const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) = 0;
+ const void *B, const int ldb, /* batches share B */ const int B_multi_stride,
+ void *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
+ const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) = 0;
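// Editor's sketch (illustrative, not part of this patch): a caller hands its
// untyped buffers to the interface above. All pointer and stride values here
// are hypothetical placeholders.
void example_set_arrays(IGemmCommon &gemm,
                        const void *a_ptr, const void *b_ptr, void *c_ptr, const void *bias_ptr)
{
    gemm.set_arrays_generic(a_ptr, /* lda */ 64, /* A_batch_stride */ 0, /* A_multi_stride */ 0,
                            b_ptr, /* ldb */ 64, /* B_multi_stride */ 0,
                            c_ptr, /* ldc */ 64, /* C_batch_stride */ 0, /* C_multi_stride */ 0,
                            bias_ptr, /* bias_multi_stride */ 0);
}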
/** @returns an ndrange containing ranges of the compute space which can be
* broken up and parallelised over
@@ -71,47 +69,64 @@ public:
* This has an empty default implementation, as GEMMs which don't care
* about thread count can safely ignore this.
*/
- virtual void set_nthreads(int) { };
+ virtual void set_nthreads(int) {};
/* Whether this GEMM can be dynamically scheduled or not. */
- virtual bool supports_dynamic_scheduling() const { return false; }
+ virtual bool supports_dynamic_scheduling() const
+ {
+ return false;
+ }
/** Main execute member function
* @param [in] work_range specifies the range of work we want to be computed; the total range is defined by get_window_size()
* @param [in] thread_locator where we are inside the thread space
* @param [in] threadid a unique threadid
*/
- virtual void execute(const ndcoord_t& work_range, const ndcoord_t& thread_locator, int threadid) = 0;
+ virtual void execute(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid) = 0;
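// Editor's sketch (illustrative, not part of this patch): dispatching the whole
// compute window on a single thread. Constructing ndcoord_t from {start, size}
// pairs, and ndrange_max == 6, are assumptions based on ndrange.hpp; a real
// scheduler would split the window returned by get_window_size() across threads.
void example_execute(IGemmCommon &gemm)
{
    gemm.set_nthreads(1);
    const auto window = gemm.get_window_size();
    const ndcoord_t work_range{{0, window.get_size(0)}, {0, window.get_size(1)}, {0, window.get_size(2)},
                               {0, window.get_size(3)}, {0, window.get_size(4)}, {0, window.get_size(5)}};
    const ndcoord_t thread_locator{{0, 1}, {0, 1}, {0, 1}, {0, 1}, {0, 1}, {0, 1}};
    gemm.execute(work_range, thread_locator, /* threadid */ 0);
}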
/*** Working space interface (optional) ***/
/* Total number of bytes of temporary working space needed. If zero, it's not necessary to call set_working_space(). */
- virtual size_t get_working_size() const { return 0; }
+ virtual size_t get_working_size() const
+ {
+ return 0;
+ }
/* Provide working space buffer - the void * passed in must remain allocated for the duration of any execute calls. */
- virtual void set_working_space(void *) { };
+ virtual void set_working_space(void *) {};
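// Editor's sketch (illustrative, not part of this patch): honouring the
// working-space contract. std::vector (<vector>, <cstdint>) is the editor's
// choice of allocation; the buffer must outlive every execute() call.
void example_workspace(IGemmCommon &gemm, std::vector<uint8_t> &workspace)
{
    workspace.resize(gemm.get_working_size());
    if (!workspace.empty())
    {
        gemm.set_working_space(workspace.data());
    }
}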
/*** "Pretransposed" interface (optional) ***/
/* Is this object set up for pretranspose? If so, pretranspose_array() needs to be called before execute(); */
- virtual bool B_is_pretransposed() const { return false; }
+ virtual bool B_is_pretransposed() const
+ {
+ return false;
+ }
/* Does pretranspose still need to be done? */
- virtual bool B_pretranspose_required() const { return false; }
+ virtual bool B_pretranspose_required() const
+ {
+ return false;
+ }
/* Total number of bytes of space needed for pretransposed arrays. */
- virtual size_t get_B_pretransposed_array_size() const { return 0; }
+ virtual size_t get_B_pretransposed_array_size() const
+ {
+ return 0;
+ }
/* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */
/* The "real" version of this depends on the templated operand type (see below). */
virtual void pretranspose_B_array_generic(void *, const void *, const int, const int) = 0;
/* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */
- virtual void set_pretransposed_B_data(void *) { }
+ virtual void set_pretransposed_B_data(void *)
+ {
+ }
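// Editor's sketch (illustrative, not part of this patch): the pretranspose
// protocol described above. b_ptr/ldb/b_multi_stride are the caller's original
// B arguments; 'scratch' must point at get_B_pretransposed_array_size() bytes
// and remain allocated for the duration of any execute() calls.
void example_pretranspose(IGemmCommon &gemm, void *scratch,
                          const void *b_ptr, int ldb, int b_multi_stride)
{
    if (gemm.B_is_pretransposed() && gemm.B_pretranspose_required())
    {
        gemm.pretranspose_B_array_generic(scratch, b_ptr, ldb, b_multi_stride);
    }
}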
/*** "Quantized bias" interface (optional) ***/
/* Set the bias vector for quantized GEMMs */
- virtual void set_quantized_bias(const int32_t *bias, size_t bias_multi_stride)
+ virtual void set_quantized_bias(const int32_t *, size_t)
{
- UNUSED(bias);
- UNUSED(bias_multi_stride);
}
// Destructor
- virtual ~IGemmCommon() { }
+ virtual ~IGemmCommon()
+ {
+ }
};
/* "Real" GemmCommon class which is templated on the operand and return types.
@@ -121,50 +136,53 @@ public:
* 'set_arrays' to capture the provided arguments in protected class
* members, as essentially any implementation will need these.
*/
-template<typename To, typename Tr>
-class GemmCommon : public IGemmCommon {
+template <typename To, typename Tr>
+class GemmCommon : public IGemmCommon
+{
protected:
- const To *_Aptr=nullptr;
- int _lda=0;
- int _A_batch_stride=0;
- int _A_multi_stride=0;
- const To *_Bptr=nullptr;
- int _ldb=0;
- int _B_multi_stride=0;
- Tr *_Cptr=nullptr;
- int _ldc=0;
- int _C_batch_stride=0;
- int _C_multi_stride=0;
- const Tr *_bias=nullptr;
- int _bias_multi_stride=0;
+ const To *_Aptr = nullptr;
+ int _lda = 0;
+ int _A_batch_stride = 0;
+ int _A_multi_stride = 0;
+ const To *_Bptr = nullptr;
+ int _ldb = 0;
+ int _B_multi_stride = 0;
+ Tr *_Cptr = nullptr;
+ int _ldc = 0;
+ int _C_batch_stride = 0;
+ int _C_multi_stride = 0;
+ const Tr *_bias = nullptr;
+ int _bias_multi_stride = 0;
public:
/* Pass in the pointers to the arrays to be operated on and their
* strides (templated version with appropriate types). */
virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
- const To *B, const int ldb, /* batches share B */ const int B_multi_stride,
- Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
- const Tr *bias, /* no row or batch stride needed */ const int bias_multi_stride) {
- _Aptr = A;
- _lda = lda;
- _A_batch_stride = A_batch_stride;
- _A_multi_stride = A_multi_stride;
- _Bptr = B;
- _ldb = ldb;
- _B_multi_stride = B_multi_stride;
- _Cptr = C;
- _ldc = ldc;
- _C_batch_stride = C_batch_stride;
- _C_multi_stride = C_multi_stride;
- _bias = bias;
+ const To *B, const int ldb, /* batches share B */ const int B_multi_stride,
+ Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
+ const Tr *bias, /* no row or batch stride needed */ const int bias_multi_stride)
+ {
+ _Aptr = A;
+ _lda = lda;
+ _A_batch_stride = A_batch_stride;
+ _A_multi_stride = A_multi_stride;
+ _Bptr = B;
+ _ldb = ldb;
+ _B_multi_stride = B_multi_stride;
+ _Cptr = C;
+ _ldc = ldc;
+ _C_batch_stride = C_batch_stride;
+ _C_multi_stride = C_multi_stride;
+ _bias = bias;
_bias_multi_stride = bias_multi_stride;
}
/* Implementation of the void * overload which casts its arguments to the appropriate type. */
void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride,
- const void *B, const int ldb, /* batches share B */ const int B_multi_stride,
- void *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
- const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) override {
+ const void *B, const int ldb, /* batches share B */ const int B_multi_stride,
+ void *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
+ const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) override
+ {
set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride,
static_cast<const To *>(B), ldb, B_multi_stride,
static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride,
@@ -175,27 +193,13 @@ public:
/* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
/* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
- virtual void pretranspose_B_array(void *, const To *, const int, const int) { };
+ virtual void pretranspose_B_array(void *, const To *, const int, const int) {};
/* Implementation of the void * overload which casts its arguments to the appropriate type. */
- void pretranspose_B_array_generic(void *out, const void *in, const int row_stride, const int multi_stride) override {
+ void pretranspose_B_array_generic(void *out, const void *in, const int row_stride, const int multi_stride) override
+ {
pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride);
}
};
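// Editor's sketch (illustrative, not part of this patch): a concrete GEMM
// derived from GemmCommon<To, Tr> inherits the operand pointers and strides
// captured by set_arrays() as the protected members above, so only the pure
// virtuals remain to be implemented. The ndrange_t constructor used here is an
// assumption based on ndrange.hpp.
class ExampleGemm : public GemmCommon<float, float>
{
public:
    ndrange_t get_window_size() const override
    {
        return ndrange_t{1u, 1u, 1u, 1u, 1u, 1u}; // a single unit of work
    }
    void execute(const ndcoord_t &, const ndcoord_t &, int) override
    {
        // ... compute _Cptr from _Aptr/_Bptr using _lda, _ldb and _ldc ...
    }
};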
-template<typename GemmKernel>
-inline
-int unsigned get_total_window_size(const GemmKernel& kernel)
-{
- auto window=kernel.get_window_size();
-
- unsigned int total = 1;
- for(unsigned i = 0; i != arm_gemm::ndrange_max; ++i)
- {
- total *= window.get_size(i);
- }
-
- return total;
-}
-
} // namespace arm_gemm