diff options
Diffstat (limited to 'src/core/NEON/kernels/assembly/gemm_common.hpp')
-rw-r--r-- | src/core/NEON/kernels/assembly/gemm_common.hpp | 150 |
1 files changed, 77 insertions, 73 deletions
diff --git a/src/core/NEON/kernels/assembly/gemm_common.hpp b/src/core/NEON/kernels/assembly/gemm_common.hpp index a44b774b9d..3b4c025371 100644 --- a/src/core/NEON/kernels/assembly/gemm_common.hpp +++ b/src/core/NEON/kernels/assembly/gemm_common.hpp @@ -23,15 +23,12 @@ */ #pragma once -#include "arm_gemm_compute_iface.hpp" +#include "ndrange.hpp" #include <cstddef> -#include <cassert> - -#define UNUSED(x) (void)(x) - -namespace arm_gemm { +namespace arm_gemm +{ // Abstract class for the GEMM/GEMV functions. // // GEMM implementations may be "native" (never require any input @@ -41,7 +38,8 @@ namespace arm_gemm { // The real GemmCommon class is templated based on the operand and return // type. This is an interface class which is independent of those types. -class IGemmCommon { +class IGemmCommon +{ public: /* Pass in the pointers to the arrays to be operated on and their * strides. This "generic" version uses void *s, the preferred version @@ -50,9 +48,9 @@ public: * the settings for B here are ignored. */ virtual void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride, - const void *B, const int ldb, /* batches share B */ const int B_multi_stride, - void *C, const int ldc, const int C_batch_stride, const int C_multi_stride, - const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) = 0; + const void *B, const int ldb, /* batches share B */ const int B_multi_stride, + void *C, const int ldc, const int C_batch_stride, const int C_multi_stride, + const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) = 0; /** @returns an ndrange containing ranges of the compute space which can be * broken up and parallelised over @@ -71,47 +69,64 @@ public: * This has an empty default implementation, as GEMMs which don't care * about thread count can safely ignore this. */ - virtual void set_nthreads(int) { }; + virtual void set_nthreads(int) {}; /* Whether this GEMM can be dynamically scheduled or not. */ - virtual bool supports_dynamic_scheduling() const { return false; } + virtual bool supports_dynamic_scheduling() const + { + return false; + } /** Main execute member fucntion * @param [in] work_range specifies the range of work we want to be computed, total range defined by get_window_size() * @param [in] thread_locator where are we inside of the thread space * @naram [in] threadid a unique threadid */ - virtual void execute(const ndcoord_t& work_range, const ndcoord_t& thread_locator, int threadid) = 0; + virtual void execute(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid) = 0; /*** Working space interface (optional) ***/ /* Total number of bytes of temporary working space needed. If zero, it's not necessary to call set_working_space(). */ - virtual size_t get_working_size() const { return 0; } + virtual size_t get_working_size() const + { + return 0; + } /* Provide working space buffer - the void * passed in must remain allocated for the duration of any execute calls. */ - virtual void set_working_space(void *) { }; + virtual void set_working_space(void *) {}; /*** "Pretransposed" interface (optional) ***/ /* Is this object set up for pretranspose? If so, pretranspose_array() needs to be called before execute(); */ - virtual bool B_is_pretransposed() const { return false; } + virtual bool B_is_pretransposed() const + { + return false; + } /* Does pretranspose still need to be done? */ - virtual bool B_pretranspose_required() const { return false; } + virtual bool B_pretranspose_required() const + { + return false; + } /* Total number of bytes of space needed for pretransposed arrays. */ - virtual size_t get_B_pretransposed_array_size() const { return 0; } + virtual size_t get_B_pretransposed_array_size() const + { + return 0; + } /* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */ /* The "real" version of this depends on the templated operand type (see below). */ virtual void pretranspose_B_array_generic(void *, const void *, const int, const int) = 0; /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */ - virtual void set_pretransposed_B_data(void *) { } + virtual void set_pretransposed_B_data(void *) + { + } /*** "Quantized bias" interface (optional) ***/ /* Set the bias vector for quantized GEMMs */ - virtual void set_quantized_bias(const int32_t *bias, size_t bias_multi_stride) + virtual void set_quantized_bias(const int32_t *, size_t) { - UNUSED(bias); - UNUSED(bias_multi_stride); } // Destructor - virtual ~IGemmCommon() { } + virtual ~IGemmCommon() + { + } }; /* "Real" GemmCommon class which is templated on the operand and return types. @@ -121,50 +136,53 @@ public: * 'set_arrays' to capture the provided arguments in protected class * members, as essentially any implementation will need these. */ -template<typename To, typename Tr> -class GemmCommon : public IGemmCommon { +template <typename To, typename Tr> +class GemmCommon : public IGemmCommon +{ protected: - const To *_Aptr=nullptr; - int _lda=0; - int _A_batch_stride=0; - int _A_multi_stride=0; - const To *_Bptr=nullptr; - int _ldb=0; - int _B_multi_stride=0; - Tr *_Cptr=nullptr; - int _ldc=0; - int _C_batch_stride=0; - int _C_multi_stride=0; - const Tr *_bias=nullptr; - int _bias_multi_stride=0; + const To *_Aptr = nullptr; + int _lda = 0; + int _A_batch_stride = 0; + int _A_multi_stride = 0; + const To *_Bptr = nullptr; + int _ldb = 0; + int _B_multi_stride = 0; + Tr *_Cptr = nullptr; + int _ldc = 0; + int _C_batch_stride = 0; + int _C_multi_stride = 0; + const Tr *_bias = nullptr; + int _bias_multi_stride = 0; public: /* Pass in the pointers to the arrays to be operated on and their * strides (templated version with appropriate types). */ virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride, - const To *B, const int ldb, /* batches share B */ const int B_multi_stride, - Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride, - const Tr *bias, /* no row or batch stride needed */ const int bias_multi_stride) { - _Aptr = A; - _lda = lda; - _A_batch_stride = A_batch_stride; - _A_multi_stride = A_multi_stride; - _Bptr = B; - _ldb = ldb; - _B_multi_stride = B_multi_stride; - _Cptr = C; - _ldc = ldc; - _C_batch_stride = C_batch_stride; - _C_multi_stride = C_multi_stride; - _bias = bias; + const To *B, const int ldb, /* batches share B */ const int B_multi_stride, + Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride, + const Tr *bias, /* no row or batch stride needed */ const int bias_multi_stride) + { + _Aptr = A; + _lda = lda; + _A_batch_stride = A_batch_stride; + _A_multi_stride = A_multi_stride; + _Bptr = B; + _ldb = ldb; + _B_multi_stride = B_multi_stride; + _Cptr = C; + _ldc = ldc; + _C_batch_stride = C_batch_stride; + _C_multi_stride = C_multi_stride; + _bias = bias; _bias_multi_stride = bias_multi_stride; } /* Implementation of the void * overload which casts its arguments to the appropriate type. */ void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride, - const void *B, const int ldb, /* batches share B */ const int B_multi_stride, - void *C, const int ldc, const int C_batch_stride, const int C_multi_stride, - const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) override { + const void *B, const int ldb, /* batches share B */ const int B_multi_stride, + void *C, const int ldc, const int C_batch_stride, const int C_multi_stride, + const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) override + { set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const To *>(B), ldb, B_multi_stride, static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride, @@ -175,27 +193,13 @@ public: /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */ /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */ - virtual void pretranspose_B_array(void *, const To *, const int, const int) { }; + virtual void pretranspose_B_array(void *, const To *, const int, const int) {}; /* Implementation of the void * overload which casts its arguments to the appropriate type. */ - void pretranspose_B_array_generic(void *out, const void *in, const int row_stride, const int multi_stride) override { + void pretranspose_B_array_generic(void *out, const void *in, const int row_stride, const int multi_stride) override + { pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride); } }; -template<typename GemmKernel> -inline -int unsigned get_total_window_size(const GemmKernel& kernel) -{ - auto window=kernel.get_window_size(); - - unsigned int total = 1; - for(unsigned i = 0; i != arm_gemm::ndrange_max; ++i) - { - total *= window.get_size(i); - } - - return total; -} - } // namespace arm_gemm |