diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_batched.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemv_batched.hpp | 46 |
1 file changed, 16 insertions, 30 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp index bb09770efc..d91b44b9a8 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp @@ -25,84 +25,70 @@ #include "arm_gemm.hpp" -namespace arm_gemm -{ +namespace arm_gemm { /* "Batched GEMV" (where M=1 and nbatches>1) can be executed much more * efficiently as a GEMM (with M'=nbatches and nbatches'=1). This wrapper * implements this. */ -template <typename To, typename Tr> -class GemvBatched : public GemmCommon<To, Tr> -{ +template<typename To, typename Tr> +class GemvBatched : public GemmCommon<To, Tr> { private: UniqueGemmCommon<To, Tr> _subgemm = nullptr; public: GemvBatched(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, - const To alpha, const To beta, const int maxthreads, const bool pretransposed_hint) - { + const To alpha, const To beta, const int maxthreads, const bool pretransposed_hint) { /* Just create a subgemm with batches->M */ - _subgemm = gemm<To, Tr>(ci, nbatches, N, K, 1, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint); + _subgemm = gemm<To,Tr>(ci, nbatches, N, K, 1, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint); } void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride, const To *B, const int ldb, const int B_multi_stride, - Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) override - { + Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) override { /* A and C's batch stride becomes their new row stride. New batch stride is 0 as nbatches for subgemm is always 1. 
*/ _subgemm->set_arrays(A, A_batch_stride, 0, A_multi_stride, B, ldb, B_multi_stride, C, C_batch_stride, 0, C_multi_stride); } - unsigned int get_window_size() const override - { + unsigned int get_window_size() const override { return _subgemm->get_window_size(); } - void set_nthreads(int nthreads) override - { + void set_nthreads(int nthreads) override { _subgemm->set_nthreads(nthreads); } - void execute(unsigned int start, unsigned int end, int threadid) override - { + void execute(unsigned int start, unsigned int end, int threadid) override { _subgemm->execute(start, end, threadid); } - size_t get_working_size() const override - { + size_t get_working_size() const override { return _subgemm->get_working_size(); } - void set_working_space(void *space) override - { + void set_working_space(void *space) override { _subgemm->set_working_space(space); } - bool B_is_pretransposed() const override - { + bool B_is_pretransposed() const override { return _subgemm->B_is_pretransposed(); } - bool B_pretranspose_required() const override - { + bool B_pretranspose_required() const override { return _subgemm->B_pretranspose_required(); } - size_t get_B_pretransposed_array_size() const override - { + size_t get_B_pretransposed_array_size() const override { return _subgemm->get_B_pretransposed_array_size(); } - void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override - { + void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override { _subgemm->pretranspose_B_array(buffer, B, ldb, B_multi_stride); } - void set_pretransposed_B_data(void *buffer) override - { + void set_pretransposed_B_data(void *buffer) override { _subgemm->set_pretransposed_B_data(buffer); } }; |