From 5f707736413aeac77818c42838296966f8dc6761 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Tue, 3 Jul 2018 16:22:02 +0100 Subject: COMPMID-1369: Revert accidental formatting of RSH's repo Pulled latest fixes from David's repo: commit f43ebe932c84083332b0b1a0348241b69dda63a7 Author: David Mansell Date: Tue Jul 3 18:09:01 2018 +0100 Whitespace tidying, fixed comment in gemv_batched imported from ACL. Change-Id: Ie37a623f44e90d88072236cb853ac55ac82d5f51 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/138530 Tested-by: Jenkins Reviewed-by: Georgios Pinitas Reviewed-by: Gian Marco Iodice Reviewed-by: David Mansell Reviewed-by: Anthony Barbier --- .../core/NEON/kernels/assembly/arm_gemm.hpp | 16 + .../core/NEON/kernels/assembly/gemm_common.hpp | 5 +- src/core/NEON/kernels/arm_gemm/asmlib.hpp | 85 +- src/core/NEON/kernels/arm_gemm/buffer_manager.hpp | 198 ++- src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp | 20 +- src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 27 +- src/core/NEON/kernels/arm_gemm/gemm_int16.cpp | 9 +- src/core/NEON/kernels/arm_gemm/gemm_int8.cpp | 14 +- .../NEON/kernels/arm_gemm/gemm_interleaved.hpp | 383 +++--- src/core/NEON/kernels/arm_gemm/gemm_native.hpp | 52 +- src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp | 9 +- src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp | 14 +- src/core/NEON/kernels/arm_gemm/gemv_batched.hpp | 46 +- .../kernels/arm_gemm/gemv_native_transposed.hpp | 55 +- .../NEON/kernels/arm_gemm/gemv_pretransposed.hpp | 87 +- .../kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp | 27 +- .../kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp | 630 ++++----- .../arm_gemm/kernels/a32_sgemm_8x6/a55r1.cpp | 621 ++++----- .../arm_gemm/kernels/a32_sgemm_8x6/generic.cpp | 520 ++++---- .../kernels/arm_gemm/kernels/a64_gemm_s16_12x8.hpp | 23 +- .../arm_gemm/kernels/a64_gemm_s16_12x8/generic.cpp | 536 ++++---- .../kernels/arm_gemm/kernels/a64_gemm_s8_12x8.hpp | 31 +- .../arm_gemm/kernels/a64_gemm_s8_12x8/a55r1.cpp | 335 ++--- .../a64_gemm_s8_12x8/dot_toolchain_support.h | 84 +- .../arm_gemm/kernels/a64_gemm_s8_12x8/generic.cpp | 535 ++++---- .../kernels/arm_gemm/kernels/a64_gemm_s8_4x4.hpp | 32 +- .../arm_gemm/kernels/a64_gemm_s8_4x4/generic.cpp | 728 +++++------ .../kernels/arm_gemm/kernels/a64_gemm_u16_12x8.hpp | 23 +- .../arm_gemm/kernels/a64_gemm_u16_12x8/generic.cpp | 534 ++++---- .../kernels/arm_gemm/kernels/a64_gemm_u8_12x8.hpp | 31 +- .../arm_gemm/kernels/a64_gemm_u8_12x8/a55r1.cpp | 335 ++--- .../a64_gemm_u8_12x8/dot_toolchain_support.h | 83 +- .../arm_gemm/kernels/a64_gemm_u8_12x8/generic.cpp | 535 ++++---- .../kernels/arm_gemm/kernels/a64_gemm_u8_4x4.hpp | 28 +- .../arm_gemm/kernels/a64_gemm_u8_4x4/generic.cpp | 460 +++---- .../kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp | 29 +- .../arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp | 580 +++++---- .../arm_gemm/kernels/a64_hgemm_24x8/generic.cpp | 534 ++++---- .../kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp | 27 +- .../arm_gemm/kernels/a64_sgemm_12x8/a53.cpp | 544 ++++---- .../arm_gemm/kernels/a64_sgemm_12x8/a55.cpp | 596 ++++----- .../arm_gemm/kernels/a64_sgemm_12x8/a55r1.cpp | 554 ++++---- .../arm_gemm/kernels/a64_sgemm_12x8/generic.cpp | 536 ++++---- .../arm_gemm/kernels/a64_sgemm_native_16x4.hpp | 17 +- .../kernels/a64_sgemm_native_16x4/generic.cpp | 1285 ++++++++++--------- .../arm_gemm/kernels/a64_sgemv_pretransposed.hpp | 21 +- .../kernels/a64_sgemv_pretransposed/generic.cpp | 953 +++++++------- .../kernels/arm_gemm/kernels/a64_sgemv_trans.hpp | 17 +- .../arm_gemm/kernels/a64_sgemv_trans/generic.cpp | 1331 ++++++++++---------- 
 src/core/NEON/kernels/arm_gemm/mergeresults.hpp    |   37 +-
 .../arm_gemm/merges/a32_merge_float_8x6.hpp        |  253 ++--
 .../arm_gemm/merges/a64_merge_float_12x8.hpp       |  403 +++---
 .../merges/a64_merge_float_to_half_12x8.hpp        |  512 +++++---
 .../arm_gemm/merges/a64_merge_half_24x8.hpp        |  420 +++---
 .../arm_gemm/merges/a64_merge_int32_12x8.hpp       |  340 +++--
 src/core/NEON/kernels/arm_gemm/profiler.hpp        |   91 +-
 src/core/NEON/kernels/arm_gemm/transform.hpp       |   67 +-
 .../transforms/a32_interleave_6way_32bit.hpp       |  109 +-
 .../a32_transpose_interleave_8way_32bit.hpp        |  129 +-
 .../transforms/a64_block16_interleave4_8bit.hpp    |   71 +-
 .../transforms/a64_interleave_8way_16bit.hpp       |  131 +-
 .../transforms/a64_interleave_8way_32bit.hpp       |   61 +-
 .../a64_interleave_8way_half_to_float.hpp          |  180 +--
 .../a64_transpose_interleave_12way_16bit.hpp       |  158 +--
 ...64_transpose_interleave_12way_half_to_float.hpp |  119 +-
 .../a64_transpose_interleave_24way_16bit.hpp       |  113 +-
 .../transforms/transpose_interleave_common.hpp     |  218 ++--
 src/core/NEON/kernels/arm_gemm/utils.hpp           |   27 +-
 68 files changed, 9060 insertions(+), 8554 deletions(-)

diff --git a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
index 0a541c6db9..8d1433dd24 100644
--- a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
@@ -33,10 +33,26 @@ namespace arm_gemm {
 template
 using UniqueGemmCommon = std::unique_ptr >;

+/** Request an object to process a GEMM.
+ *
+ * @param[in] ci                 Describes CPU properties.
+ * @param[in] M                  Rows in output matrix C (and input matrix A).
+ * @param[in] N                  Columns in output matrix C (and input matrix B).
+ * @param[in] K                  Columns of input matrix A (= rows of input matrix B).
+ * @param[in] nbatches           Number of "batched" GEMMs (unique A and C, shared B).
+ * @param[in] nmulti             Number of "multi" GEMMs (unique A, B and C).
+ * @param[in] trA                Does the A tensor have its rows and columns transposed?
+ * @param[in] trB                Does the B tensor have its rows and columns transposed?
+ * @param[in] alpha              Scalar multiplier to apply to the AB matrix product.
+ * @param[in] beta               Scalar multiplier to apply to the input C matrix before adding the product.
+ * @param[in] maxthreads         Maximum (and default) number of threads that will call the execute method.
+ * @param[in] pretransposed_hint Can the B tensor be pretransposed (i.e. shared across invocations)?
+ */
 template
 UniqueGemmCommon gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
                       const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB,
                       const Tret alpha, const Tret beta, const int maxthreads, const bool pretransposed_hint);
+
 } // namespace arm_gemm

diff --git a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
index 3919c339bf..b43d6eaca6 100644
--- a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
@@ -53,10 +53,11 @@ public:
     /* Pass in the pointers to the arrays to be operated on and their
      * strides. This has a default implementation that just captures them
      * all in protected members. If B is pretransposed (see below) then the
-     * settings for B here are ignored.
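As a reading aid for the factory interface documented above, here is a minimal usage sketch. It is not part of this patch: it only uses calls that appear elsewhere in this diff (gemm(), set_arrays(), get_working_size(), set_working_space(), get_window_size(), execute()), and it assumes a suitable CPUInfo is already available; the include path and the single-threaded scheduling are illustrative assumptions.

// Hedged sketch: run one single-threaded SGEMM through the factory interface.
#include <cstdint>
#include <vector>
#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"  // path as used in this patch

using namespace arm_gemm;

void run_sgemm_once(const CPUInfo &ci,
                    const float *A, int lda, const float *B, int ldb,
                    float *C, int ldc,
                    unsigned int M, unsigned int N, unsigned int K) {
    // One batch, one multi, no transposes, C = 1.0f*A*B + 0.0f*C,
    // one thread, no pretranspose hint.
    UniqueGemmCommon<float, float> gm =
        gemm<float, float>(ci, M, N, K, 1, 1, false, false, 1.0f, 0.0f, 1, false);

    // Batch/multi strides can be left at 0 because nbatches == nmulti == 1.
    gm->set_arrays(A, lda, 0, 0, B, ldb, 0, C, ldc, 0, 0);

    // The caller owns the working space; ask how much is needed.
    std::vector<int8_t> working_space(gm->get_working_size());
    gm->set_working_space(working_space.data());

    // Single-threaded: this thread processes the whole window as thread id 0.
    gm->execute(0, gm->get_window_size(), 0);
}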
+ */ virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride, const To *B, const int ldb, /* batches share B */ const int B_multi_stride, - Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) { + Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) { _Aptr = A; _lda = lda; _A_batch_stride = A_batch_stride; diff --git a/src/core/NEON/kernels/arm_gemm/asmlib.hpp b/src/core/NEON/kernels/arm_gemm/asmlib.hpp index b3fcb33bfb..38f51ae72c 100644 --- a/src/core/NEON/kernels/arm_gemm/asmlib.hpp +++ b/src/core/NEON/kernels/arm_gemm/asmlib.hpp @@ -31,21 +31,21 @@ // used by the workaround. // "Correct" version -#define ASM_PREFETCH(address) "PRFM PLDL1KEEP, " address "\n" -#define ASM_PREFETCHL2(address) "PRFM PLDL2KEEP, " address "\n" -#define ASM_PREFETCHW(address) "PRFM PSTL1KEEP, " address "\n" +#define ASM_PREFETCH(address) "PRFM PLDL1KEEP, " address "\n" +#define ASM_PREFETCHL2(address) "PRFM PLDL2KEEP, " address "\n" +#define ASM_PREFETCHW(address) "PRFM PSTL1KEEP, " address "\n" #define ASM_PREFETCHWL2(address) "PRFM PSTL2KEEP, " address "\n" // Lee's uarchsim hack -//#define ASM_PREFETCH(address) "LDNP x20, x21, " address "\n" +//#define ASM_PREFETCH(address) "LDNP x20, x21, " address "\n" // No preload at all //#define ASM_PREFETCH(address) "" #else // "Correct" versions for AArch32 -#define ASM_PREFETCH(address) "PLD " address "\n" -#define ASM_PREFETCHW(address) "PLDW " address "\n" +#define ASM_PREFETCH(address) "PLD " address "\n" +#define ASM_PREFETCHW(address) "PLDW " address "\n" #endif @@ -53,76 +53,77 @@ * Do some prefetches. */ template -static inline void prefetch_6x(const T *pfp) -{ - __asm __volatile( +static inline void prefetch_6x(const T *pfp) { + __asm __volatile ( ASM_PREFETCH("[%[pfp]]") ASM_PREFETCH("[%[pfp], #64]") ASM_PREFETCH("[%[pfp], #128]") ASM_PREFETCH("[%[pfp], #192]") ASM_PREFETCH("[%[pfp], #256]") ASM_PREFETCH("[%[pfp], #320]") - : - : [pfp] "r"(pfp) - : "memory"); + : + : [pfp] "r" (pfp) + : "memory" + ); } template -static inline void prefetch_5x(const T *pfp) -{ - __asm __volatile( +static inline void prefetch_5x(const T *pfp) { + __asm __volatile ( ASM_PREFETCH("[%[pfp]]") ASM_PREFETCH("[%[pfp], #64]") ASM_PREFETCH("[%[pfp], #128]") ASM_PREFETCH("[%[pfp], #192]") ASM_PREFETCH("[%[pfp], #256]") - : - : [pfp] "r"(pfp) - : "memory"); + : + : [pfp] "r" (pfp) + : "memory" + ); } template -static inline void prefetch_4x(const T *pfp) -{ - __asm __volatile( +static inline void prefetch_4x(const T *pfp) { + __asm __volatile ( ASM_PREFETCH("[%[pfp]]") ASM_PREFETCH("[%[pfp], #64]") ASM_PREFETCH("[%[pfp], #128]") ASM_PREFETCH("[%[pfp], #192]") - : - : [pfp] "r"(pfp) - : "memory"); + : + : [pfp] "r" (pfp) + : "memory" + ); } template -static inline void prefetch_3x(const T *pfp) -{ - __asm __volatile( +static inline void prefetch_3x(const T *pfp) { + __asm __volatile ( ASM_PREFETCH("[%[pfp]]") ASM_PREFETCH("[%[pfp], #64]") ASM_PREFETCH("[%[pfp], #128]") - : - : [pfp] "r"(pfp) - : "memory"); + : + : [pfp] "r" (pfp) + : "memory" + ); } template -static inline void prefetch_2x(const T *pfp) -{ - __asm __volatile( +static inline void prefetch_2x(const T *pfp) { + __asm __volatile ( ASM_PREFETCH("[%[pfp]]") ASM_PREFETCH("[%[pfp], #64]") - : - : [pfp] "r"(pfp) - : "memory"); + : + : [pfp] "r" (pfp) + : "memory" + ); } template -static inline void prefetch_1x(const T *pfp) -{ - __asm __volatile( +static inline void prefetch_1x(const T *pfp) { + __asm __volatile ( ASM_PREFETCH("[%[pfp]]") 
- : - : [pfp] "r"(pfp) - : "memory"); + : + : [pfp] "r" (pfp) + : "memory" + ); } + diff --git a/src/core/NEON/kernels/arm_gemm/buffer_manager.hpp b/src/core/NEON/kernels/arm_gemm/buffer_manager.hpp index dd74744ebc..03f099d57b 100644 --- a/src/core/NEON/kernels/arm_gemm/buffer_manager.hpp +++ b/src/core/NEON/kernels/arm_gemm/buffer_manager.hpp @@ -38,36 +38,33 @@ #endif -namespace arm_gemm -{ +namespace arm_gemm { + #ifndef NO_MULTI_THREADING -enum class BufferStatus -{ +enum class BufferStatus { IDLE, POPULATING, BUSY }; -class Buffer -{ +class Buffer { private: - const int _maxusers; // Maximum permissible threads. - void *const _storage; // Storage for buffer content. + const int _maxusers; // Maximum permissible threads. + void * const _storage; // Storage for buffer content. - int _numusers; // Actual number of threads (might be lower). + int _numusers; // Actual number of threads (might be lower). - volatile BufferStatus _status = BufferStatus::IDLE; // Status - std::atomic_int _users = {}; // How many users are still using the buffer. - volatile int _index = 0; // Which block of data currently resides in the buffer. + volatile BufferStatus _status = BufferStatus::IDLE; // Status + std::atomic_int _users = { }; // How many users are still using the buffer. + volatile int _index = 0; // Which block of data currently resides in the buffer. - std::mutex _lock = {}; + std::mutex _lock = { }; #ifdef USE_SEMAPHORE - std::condition_variable _cv = {}; + std::condition_variable _cv = { }; #endif template - void populate_buffer(T func) - { + void populate_buffer(T func) { func(_storage); /* Now mark it as ready. */ @@ -78,17 +75,15 @@ private: _cv.notify_all(); } #else - _status = BufferStatus::BUSY; + _status = BufferStatus::BUSY; #endif } public: Buffer(Buffer &) = delete; - Buffer &operator=(Buffer &) = delete; + Buffer &operator= (Buffer &) = delete; - Buffer(void *storage, int maxusers) - : _maxusers(maxusers), _storage(storage), _numusers(maxusers) - { + Buffer(void *storage, int maxusers) : _maxusers(maxusers), _storage(storage), _numusers(maxusers) { _status = BufferStatus::IDLE; } @@ -99,38 +94,32 @@ public: * If it's already being populated by another thread or is ready, return. */ template - void try_populate(const int index, T func) - { - for(;;) - { + void try_populate(const int index, T func) { + for (;;) { #ifdef USE_SEMAPHORE /* If it's busy with a previous index, wait on the semaphore. */ - if((_status == BufferStatus::BUSY) && (_index != index)) - { + if ((_status == BufferStatus::BUSY) && (_index != index)) { std::unique_lock ul(_lock); - if((_status == BufferStatus::BUSY) && (_index != index)) - { + if ((_status == BufferStatus::BUSY) && (_index != index)) { _cv.wait(ul); } } #endif /* Return if another thread is populating it already. */ - if((_index == index) && ((_status == BufferStatus::POPULATING) || (_status == BufferStatus::BUSY))) - { + if ((_index == index) && + ((_status == BufferStatus::POPULATING) || (_status == BufferStatus::BUSY))) { return; } - if(_status == BufferStatus::IDLE) - { + if (_status == BufferStatus::IDLE) { std::lock_guard guard(_lock); /* If the buffer is still idle, we can grab it and populate it. 
*/ - if(_status == BufferStatus::IDLE) - { + if (_status == BufferStatus::IDLE) { _status = BufferStatus::POPULATING; - _index = index; - _users = _numusers; + _index = index; + _users = _numusers; break; } } @@ -141,26 +130,26 @@ public: } template - void *get(const int index, T func) - { + void *get(const int index, T func) { // Loop until we achieve something. - for(;;) - { + for (;;) { // If the index is correct and the buffer status is busy then we can // just return the content. No locking is needed here as the index // cannot change (and status cannot change from BUSY) until all // users have finished. - if((_index == index) && (_status == BufferStatus::BUSY)) - { + if ((_index == index) && (_status == BufferStatus::BUSY)) { return _storage; } + + /* If the buffer still has some previous content, or is being + * populated, we can wait with the semaphore. */ #ifdef USE_SEMAPHORE - if(((_status == BufferStatus::BUSY) && (_index != index)) || (_status == BufferStatus::POPULATING)) - { + if (((_status == BufferStatus::BUSY) && (_index != index)) || + (_status == BufferStatus::POPULATING)) { std::unique_lock ul(_lock); - if(((_status == BufferStatus::BUSY) && (_index != index)) || (_status == BufferStatus::POPULATING)) - { + if (((_status == BufferStatus::BUSY) && (_index != index)) || + (_status == BufferStatus::POPULATING)) { _cv.wait(ul); } } @@ -168,17 +157,15 @@ public: // If it's idle, we need to populate it. The IDLE->POPULATING // transition requires the lock. - if(_status == BufferStatus::IDLE) - { + if (_status == BufferStatus::IDLE) { std::lock_guard guard(_lock); /* If it's still idle, grab it. Otherwise drop through and * we'll do something else next time through the loop. */ - if(_status == BufferStatus::IDLE) - { + if (_status == BufferStatus::IDLE) { _status = BufferStatus::POPULATING; - _index = index; - _users = _numusers; + _index = index; + _users = _numusers; break; } } @@ -194,10 +181,8 @@ public: * simply (atomically) decrement the user count, and if it's hit zero we * flag the buffer as idle. */ - void release(void) - { - if(--_users == 0) - { + void release(void) { + if (--_users == 0) { #ifdef USE_SEMAPHORE std::unique_lock ul(_lock); _status = BufferStatus::IDLE; @@ -211,110 +196,91 @@ public: } /* This is called to change the number of users. */ - void set_numusers(int numusers) - { + void set_numusers(int numusers) { _numusers = std::min(numusers, _maxusers); } }; -class BufferManager -{ + +class BufferManager { private: /* This has to be a vector of Buffer *, because a Buffer cannot be moved * or copied due to atomic members. */ - std::vector _buffers = {}; - const int _maxthreads; - void *const _storage; + std::vector _buffers = { }; + const int _maxthreads; + void * const _storage; public: BufferManager(BufferManager &) = delete; - BufferManager &operator=(BufferManager &) = delete; + BufferManager & operator=(BufferManager &) = delete; // Say how much storage is needed. - static inline size_t get_storage_requirement(const int maxthreads, const size_t buffersize) - { + static inline size_t get_storage_requirement(const int maxthreads, const size_t buffersize) { return buffersize * ((maxthreads == 1) ? 1 : 3); } - BufferManager(const int maxthreads, const size_t buffersize, void *storage) - : _maxthreads(maxthreads), _storage(storage) - { + BufferManager(const int maxthreads, const size_t buffersize, void *storage) : _maxthreads(maxthreads), _storage(storage) { const int numbuffers = (maxthreads == 1) ? 
1 : 3; /* We don't need any Buffer objects in single thread mode. */ - if(_maxthreads == 1) - { + if (_maxthreads == 1) { return; } /* Use intptr_t to avoid performing arithmetic on a void * */ intptr_t storage_int = reinterpret_cast(_storage); - for(int i = 0; i < numbuffers; i++) - { + for (int i=0; i(storage_int), _maxthreads)); storage_int += buffersize; } } - ~BufferManager() - { - while(_buffers.size()) - { + ~BufferManager() { + while (_buffers.size()) { delete _buffers.back(); _buffers.pop_back(); } } template - void *get(const int index, T func) - { + void *get(const int index, T func) { /* In single thread mode, we just directly call the populating * function on the (single) buffer, otherwise forward to the * relevant Buffer. */ - if(_maxthreads == 1) - { + if (_maxthreads==1) { func(_storage); return _storage; - } - else - { + } else { return _buffers[index % _buffers.size()]->get(index, func); } } template - void try_populate(const int index, T func) - { + void try_populate(const int index, T func) { /* No need for this in single thread mode. */ - if(_maxthreads == 1) - { + if (_maxthreads==1) { return; } _buffers[index % _buffers.size()]->try_populate(index, func); } - void release(const int index) - { + void release(const int index) { /* No need for this in single thread mode. */ - if(_maxthreads == 1) - { + if (_maxthreads==1) { return; } _buffers[index % _buffers.size()]->release(); } - void set_nthreads(int threads) - { - if(_maxthreads == 1) - { + void set_nthreads(int threads) { + if (_maxthreads==1) { return; } - for(unsigned int i = 0; i < _buffers.size(); i++) - { + for(unsigned int i=0; i<_buffers.size(); i++) { _buffers[i]->set_numusers(threads); } } @@ -329,49 +295,35 @@ public: * All the other methods do nothing. */ -class BufferManager -{ +class BufferManager { private: - void *const _storage; + void * const _storage; public: BufferManager(BufferManager &) = delete; - BufferManager &operator=(BufferManager &) = delete; + BufferManager & operator=(BufferManager &) = delete; - BufferManager(const int maxthreads, const size_t buffersize, void *storage) - : _storage(storage) - { - } + BufferManager(const int maxthreads, const size_t buffersize, void *storage) : _storage(storage) { } - ~BufferManager() - { - } + ~BufferManager() { } // Say how much storage is needed. 
- static inline size_t get_storage_requirement(const int maxthreads, const size_t buffersize) - { + static inline size_t get_storage_requirement(const int maxthreads, const size_t buffersize) { return buffersize; } template - void try_populate(const int index, T func) - { - } + void try_populate(const int index, T func) { } - void release(const int index) - { - } + void release(const int index) { } template - void *get(const int index, T func) - { + void *get(const int index, T func) { func(_storage); return _storage; } - void set_nthreads(int) - { - } + void set_nthreads(int) { } }; #endif diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp index d1180b13cb..fa12942829 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp @@ -30,33 +30,31 @@ #include "gemm_common.hpp" #include "gemm_interleaved.hpp" -#include "kernels/a32_sgemm_8x6.hpp" #include "kernels/a64_hgemm_24x8.hpp" #include "kernels/a64_sgemm_12x8.hpp" +#include "kernels/a32_sgemm_8x6.hpp" + +namespace arm_gemm { -namespace arm_gemm -{ -template <> +template<> UniqueGemmCommon<__fp16, __fp16> gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, const __fp16 alpha, const __fp16 beta, - const int maxthreads, const bool pretransposed_hint) -{ + const int maxthreads, const bool pretransposed_hint) { #ifdef __aarch64__ - // Only consider the native FP16 kernel if it will get built. +// Only consider the native FP16 kernel if it will get built. #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS) #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC // If the compiler is configured to enable this feature always, then assume it is available at runtime too. - const bool use_fp16 = true; + const bool use_fp16=true; #else // Otherwise, detect at runtime via CPUInfo. - const bool use_fp16 = ci.has_fp16(); + const bool use_fp16=ci.has_fp16(); #endif // If FP16 is supported, use it. 
- if(use_fp16) - { + if (use_fp16) { return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); } #endif diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index c093761614..99f061bde8 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -29,15 +29,15 @@ #include "gemv_native_transposed.hpp" #include "gemv_pretransposed.hpp" -#include "kernels/a32_sgemm_8x6.hpp" #include "kernels/a64_sgemm_12x8.hpp" -#include "kernels/a64_sgemm_native_16x4.hpp" -#include "kernels/a64_sgemv_pretransposed.hpp" +#include "kernels/a32_sgemm_8x6.hpp" #include "kernels/a64_sgemv_trans.hpp" +#include "kernels/a64_sgemv_pretransposed.hpp" +#include "kernels/a64_sgemm_native_16x4.hpp" -namespace arm_gemm -{ -template <> +namespace arm_gemm { + +template<> UniqueGemmCommon gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, const float alpha, const float beta, @@ -46,18 +46,17 @@ UniqueGemmCommon gemm(const CPUInfo &ci, const unsig if (M==1 && nbatches>1) { return UniqueGemmCommon (new GemvBatched(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); } + #ifdef __aarch64__ /* Cases in priority order */ /* GemvPretransposed: requires M=1, alpha=1, and transposed hint set. nbatches must be 1 or we would have returned above so don't test. */ - if(M == 1 && alpha == 1.0f && pretransposed_hint) - { - return UniqueGemmCommon(new GemvPretransposed(&ci, N, K, nmulti, trB, beta)); + if (M==1 && alpha==1.0f && pretransposed_hint) { + return UniqueGemmCommon (new GemvPretransposed(&ci, N, K, nmulti, trB, beta)); } /* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle alpha */ - if(M == 1 && alpha == 1.0f && !trA && !trB) - { - return UniqueGemmCommon(new GemvNativeTransposed(&ci, N, K, nmulti, beta)); + if (M==1 && alpha==1.0f && !trA && !trB) { + return UniqueGemmCommon (new GemvNativeTransposed(&ci, N, K, nmulti, beta)); } /* Native GEMM: requires K at least 4, N a multiple of 16, doesn't @@ -69,9 +68,9 @@ UniqueGemmCommon gemm(const CPUInfo &ci, const unsig } /* Blocked GEMM, handles all cases. 
*/ - return UniqueGemmCommon(new GemmInterleaved(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); + return UniqueGemmCommon (new GemmInterleaved(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); #else - return UniqueGemmCommon(new GemmInterleaved(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); + return UniqueGemmCommon (new GemmInterleaved(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); #endif } diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp index 7669fe0ff1..317541919b 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp @@ -29,14 +29,13 @@ #include "kernels/a64_gemm_s16_12x8.hpp" -namespace arm_gemm -{ -template <> +namespace arm_gemm { + +template<> UniqueGemmCommon gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, const int32_t alpha, const int32_t beta, - const int maxthreads, const bool pretransposed_hint) -{ + const int maxthreads, const bool pretransposed_hint) { return UniqueGemmCommon(new GemmInterleaved(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); } diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp index f13406284c..7eff47de68 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp @@ -27,20 +27,18 @@ #include "gemm_common.hpp" #include "gemm_interleaved.hpp" +#include "kernels/a64_gemm_s8_4x4.hpp" #include "kernels/a64_gemm_s16_12x8.hpp" #include "kernels/a64_gemm_s8_12x8.hpp" -#include "kernels/a64_gemm_s8_4x4.hpp" -namespace arm_gemm -{ -template <> +namespace arm_gemm { + +template<> UniqueGemmCommon gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, const int32_t alpha, const int32_t beta, - const int maxthreads, const bool pretransposed_hint) -{ - if(ci.has_dotprod()) - { + const int maxthreads, const bool pretransposed_hint) { + if (ci.has_dotprod()) { // Dot product supporting CPUs. This family has a special version for A55r1. return UniqueGemmCommon(new GemmInterleaved(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); } diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index 32c65cd3fb..c304edd1f9 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -23,8 +23,8 @@ */ #pragma once -#include #include +#include #include @@ -41,23 +41,22 @@ // Some macros used to decide how much working space to allocate. // Round allocations up to the next cache line. -#define ALLOC_ROUND 64 -#define ROUND_UP(x) ((((x) + ALLOC_ROUND - 1) / ALLOC_ROUND) * ALLOC_ROUND) +#define ALLOC_ROUND 64 +#define ROUND_UP(x) ((((x) + ALLOC_ROUND-1) / ALLOC_ROUND) * ALLOC_ROUND) // Implementation of the GemmCommon abstract class. // // This implementation interleaves the source matrices in blocks - good for // larger matrices. 
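The GemmInterleaved class that follows is dense because it folds threading, windowing and the buffer manager into one place. As a reading aid, the loop nest it implements is the classic blocked GEMM sketched below. This is a simplified, self-contained toy (no packing/interleaving, no threading; the row-major layout and the block sizes are illustrative assumptions), not the class's actual code.

#include <algorithm>
#include <cstddef>

// Toy blocked GEMM: C(MxN) = alpha*A(MxK)*B(KxN) + beta*C, row-major, no packing.
// The k_block / x_block / m_strip nesting mirrors the structure used by GemmInterleaved.
void blocked_gemm_sketch(const float *A, const float *B, float *C,
                         std::size_t M, std::size_t N, std::size_t K,
                         float alpha, float beta,
                         std::size_t k_block = 256, std::size_t x_block = 432,
                         std::size_t m_strip = 8) {
    for (std::size_t k0 = 0; k0 < K; k0 += k_block) {          // outer K blocking (sized for L1)
        std::size_t kmax = std::min(k0 + k_block, K);
        for (std::size_t x0 = 0; x0 < N; x0 += x_block) {       // N blocking (sized for L2)
            std::size_t xmax = std::min(x0 + x_block, N);
            for (std::size_t y0 = 0; y0 < M; y0 += m_strip) {   // M strips of kernel height
                std::size_t ymax = std::min(y0 + m_strip, M);
                for (std::size_t y = y0; y < ymax; y++) {
                    for (std::size_t x = x0; x < xmax; x++) {
                        float acc = 0.0f;
                        for (std::size_t k = k0; k < kmax; k++) {
                            acc += A[y * K + k] * B[k * N + x];
                        }
                        // beta is only applied on the first K block; later K blocks
                        // accumulate into the partial result, as in the merge step below.
                        float prev = (k0 == 0) ? beta * C[y * N + x] : C[y * N + x];
                        C[y * N + x] = prev + alpha * acc;
                    }
                }
            }
        }
    }
}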
-namespace arm_gemm -{ -template -class GemmInterleaved : public GemmCommon -{ +namespace arm_gemm { + +template +class GemmInterleaved : public GemmCommon { typedef typename strategy::operand_type Toi; - typedef typename strategy::result_type Tri; + typedef typename strategy::result_type Tri; /* const properties set by constructor */ - const CPUInfo *const _ci; + const CPUInfo * const _ci; const unsigned int _Msize; const unsigned int _Nsize; @@ -72,173 +71,138 @@ class GemmInterleaved : public GemmCommon const Tr _alpha; const Tr _beta; - const unsigned int _maxthreads; - const bool _pretransposed; + const int _maxthreads; + int _nthreads; + const bool _pretransposed; /* Blocking info */ - unsigned int _k_block = 0; - unsigned int _x_block = 0; - unsigned int _Mround = 0; + unsigned int _k_block=0; + unsigned int _x_block=0; + unsigned int _Mround=0; /* Working space, pretransposed buffer, buffer manager */ - const Toi *_B_transposed = nullptr; - BufferManager *_bm = nullptr; - void *_working_space = nullptr; + const Toi *_B_transposed=nullptr; + BufferManager *_bm=nullptr; + void *_working_space=nullptr; /* We will need to walk through the blocks of B in a few contexts, so * factor that out. */ - class blockwalker - { + class blockwalker { private: /* Size loops, etc. based on our parent's configuration */ const GemmInterleaved &_parent; - /* K and X and multi parameters for current iteration. */ - unsigned int _k0 = 0, _x0 = 0, _multi = 0; + /* K, X and multi parameters for current iteration. */ + unsigned int _k0=0, _x0=0, _multi=0; - unsigned int _index = 0; - bool _done = false; - bool _newkblock = true; - bool _newmulti = true; + unsigned int _index=0; + bool _done=false; + bool _newkblock=true; + bool _newmulti=true; public: - blockwalker(const GemmInterleaved &parent) - : _parent(parent) - { - } + blockwalker(const GemmInterleaved &parent) : _parent(parent) { } - unsigned int xmax() - { + unsigned int xmax() { return std::min(_x0 + _parent._x_block, _parent._Nsize); } - unsigned int kmax() - { + unsigned int kmax() { return std::min(_k0 + _parent._k_block, _parent._Ksize); } /* Advance to the next block, return false at the end. */ - bool advance(void) - { - if(_done) - { + bool advance(void) { + if (_done) { return false; } - _newkblock = false; + _newkblock=false; _x0 += _parent._x_block; - if(_x0 >= _parent._Nsize) - { - _x0 = 0; + if (_x0 >= _parent._Nsize) { + _x0=0; _k0 += _parent._k_block; - if(_k0 >= _parent._Ksize) - { - _k0 = 0; + if (_k0 >= _parent._Ksize) { + _k0=0; _multi++; - if(_multi >= _parent._nmulti) - { - _done = true; + if (_multi >= _parent._nmulti) { + _done=true; return false; } - _newmulti = true; + _newmulti=true; } - _newkblock = true; + _newkblock=true; } _index++; return true; } - unsigned int k0(void) - { - return _k0; - } - unsigned int x0(void) - { - return _x0; - } - unsigned int multi(void) - { - return _multi; - } - unsigned int index(void) - { - return _index; - } - bool done(void) - { - return _done; - } - bool newkblock(void) - { - return _newkblock; - } + unsigned int k0(void) { return _k0; } + unsigned int x0(void) { return _x0; } + unsigned int multi(void) { return _multi; } + unsigned int index(void) { return _index; } + bool done(void) { return _done; } + bool newkblock(void) { return _newkblock; } }; // A working size: One of these needed, regardless of thread count. Divided according to window. 
- size_t get_a_working_size() const - { + size_t get_a_working_size() const { return ROUND_UP(sizeof(Toi) * _k_block * _Mround * _nbatches); } // B working size: 0, 1 or 3 of these needed depending on pretransposed and threading settings. - size_t get_b_working_size() const - { + size_t get_b_working_size() const { return ROUND_UP(sizeof(Toi) * _x_block * _k_block); } // C working size: One needed per thread. - size_t get_c_working_size() const - { + size_t get_c_working_size() const { return ROUND_UP(sizeof(Tri) * _x_block * strategy::out_height); } // Internal execute function. // This supports both the "pretransposed" and "standard" interfaces via the template parameter. - template - void execute_internal(unsigned int start, unsigned int end, int threadid) - { + template + void execute_internal(unsigned int start, unsigned int end, int threadid) { #ifdef CYCLE_PROFILING profiler prof; #endif - strategy strat(_ci); blockwalker current(*this); - blockwalker next = current; + blockwalker next=current; /* Translate 'start' and 'end' into a position within the batches and rows. */ const unsigned int window_per_batch = _Mround / strategy::out_height; - unsigned int batch_0 = start / window_per_batch; - unsigned int batch_end = end / window_per_batch; + unsigned int batch_0 = start / window_per_batch; + unsigned int batch_end = end / window_per_batch; /* Compute the M values to operate on */ unsigned int m_0 = (start - (batch_0 * window_per_batch)) * strategy::out_height; unsigned int m_max = (end - (batch_end * window_per_batch)) * strategy::out_height; /* Make sure we've been set up correctly. */ - if(pretransposed) - { + if (pretransposed) { assert(_B_transposed); - } - else - { + } else { assert(_bm); } assert(_working_space); int8_t *working_space_bytes = reinterpret_cast(_working_space); - // Private buffers. Treat working_space as an array of C buffers (one per thread) first, followed by the (window-divided) A buffer. + // Private buffers. Treat working_space as an array of C buffers + // (one per thread) first, followed by the (window-divided) A + // buffer. // Set a_panel to the base of the A buffers - compute offsets into it based on M/batches later. - Toi *const a_panel = reinterpret_cast(working_space_bytes + (_maxthreads * get_c_working_size())); - Tri *const c_panel = reinterpret_cast(working_space_bytes + (threadid * get_c_working_size())); + Toi * const a_panel = reinterpret_cast(working_space_bytes + (_maxthreads * get_c_working_size())); + Tri * const c_panel = reinterpret_cast(working_space_bytes + (threadid * get_c_working_size())); // Shared buffers - these come either from BufferManager or _B_transposed. const Toi *b_panel; - if(pretransposed) - { + if (pretransposed) { b_panel = _B_transposed; } @@ -247,33 +211,28 @@ class GemmInterleaved : public GemmCommon // newkblock() is always true on the first iteration, so this will be set properly on the first loop. int kern_k = 0; - for(; !current.done(); current.advance()) - { - if(current.newkblock()) - { + for (;!current.done();current.advance()) { + if (current.newkblock()) { #ifdef CYCLE_PROFILING - auto p = prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height * (current.kmax() - current.k0()) * sizeof(Toi)); + auto p=prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height * (current.kmax()-current.k0()) * sizeof(Toi)); #endif - for(unsigned int batch = batch_0; batch <= batch_end; batch++) - { - unsigned int first_m = (batch == batch_0) ? 
m_0 : 0; + for (unsigned int batch = batch_0; batch <= batch_end; batch++) { + unsigned int first_m = (batch == batch_0) ? m_0 : 0; unsigned int last_m = (batch == batch_end) ? m_max : _Msize; - if(first_m >= last_m) + if (first_m >= last_m) continue; - if(_trA ^ strategy::A_transpose) - { + + if (_trA ^ strategy::A_transpose) { Transform( - a_panel + ((batch * _Mround + first_m) * _k_block), - this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride), - this->_lda, first_m, last_m, current.k0(), current.kmax()); - } - else - { + a_panel + ((batch * _Mround + first_m) * _k_block), + this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride), + this->_lda, first_m, last_m, current.k0(), current.kmax()); + } else { Transform( - a_panel + ((batch * _Mround + first_m) * _k_block), - this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride), - this->_lda, first_m, last_m, current.k0(), current.kmax()); + a_panel + ((batch * _Mround + first_m) * _k_block), + this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride), + this->_lda, first_m, last_m, current.k0(), current.kmax()); } } @@ -284,8 +243,7 @@ class GemmInterleaved : public GemmCommon int bblocks = iceildiv(current.xmax() - current.x0(), strategy::out_width); - if(!pretransposed) - { + if (!pretransposed) { /* Look ahead to the next block and populate it if necessary. * This avoids the populate operation becoming a bottleneck, and * helps keep the threads synchronized (the first thread to get @@ -294,71 +252,60 @@ class GemmInterleaved : public GemmCommon * If we are running single threaded, bm->try_populate() will do * nothing. */ - if(next.advance()) - { - _bm->try_populate(next.index(), [&](void *buffer) - { + if (next.advance()) { + _bm->try_populate(next.index(), [&](void *buffer) { #ifdef CYCLE_PROFILING - auto p = prof.ScopedProfiler(PROFILE_PREPB, (next.xmax() - next.x0()) * (next.kmax() - next.k0()) * sizeof(Toi)); + auto p=prof.ScopedProfiler(PROFILE_PREPB, (next.xmax()-next.x0()) * (next.kmax()-next.k0()) * sizeof(Toi)); #endif Toi *b_panel = reinterpret_cast(buffer); - if(_trB ^ strategy::B_transpose) - { + if (_trB ^ strategy::B_transpose) { Transform( - b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb, - next.x0(), next.xmax(), next.k0(), next.kmax()); - } - else - { + b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb, + next.x0(), next.xmax(), next.k0(), next.kmax()); + } else { Transform( - b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb, - next.x0(), next.xmax(), next.k0(), next.kmax()); + b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb, + next.x0(), next.xmax(), next.k0(), next.kmax()); } }); } + /* Get the buffer for this iteration from the BufferManager. 
*/ - b_panel = reinterpret_cast(_bm->get(current.index(), [&](void *bpv) - { + b_panel = reinterpret_cast(_bm->get(current.index(), [&](void *bpv) { #ifdef CYCLE_PROFILING - auto p = prof.ScopedProfiler(PROFILE_PREPB, (current.xmax() - current.x0()) * (current.kmax() - current.k0()) * sizeof(Toi)); + auto p=prof.ScopedProfiler(PROFILE_PREPB, (current.xmax()-current.x0()) * (current.kmax()-current.k0()) * sizeof(Toi)); #endif Toi *b_panel = reinterpret_cast(bpv); - if(_trB ^ strategy::B_transpose) - { + if (_trB ^ strategy::B_transpose) { Transform( - b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb, - current.x0(), current.xmax(), current.k0(), current.kmax()); - } - else - { + b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb, + current.x0(), current.xmax(), current.k0(), current.kmax()); + } else { Transform( - b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb, - current.x0(), current.xmax(), current.k0(), current.kmax()); + b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb, + current.x0(), current.xmax(), current.k0(), current.kmax()); } - })); } /* Do the actual work. */ - for(unsigned int batch = batch_0; batch <= batch_end; batch++) - { - unsigned int first_m = (batch == batch_0) ? m_0 : 0; + for (unsigned int batch = batch_0; batch <= batch_end; batch++) { + unsigned int first_m = (batch == batch_0) ? m_0 : 0; unsigned int last_m = (batch == batch_end) ? m_max : _Msize; const Toi *a_ptr = a_panel + (batch * _Mround + first_m) * _k_block; - if(first_m >= last_m) + if (first_m >= last_m) continue; - for(unsigned int y = first_m; y < last_m; y += strategy::out_height) - { + for (unsigned int y=first_m; y { #ifdef CYCLE_PROFILING - auto p = prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height * bblocks * strategy::out_width * sizeof(Tr))); + auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height * bblocks * strategy::out_width * sizeof(Tr))); #endif MergeResults( - this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride), - c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(), - _alpha, (current.k0() == 0 ? _beta : static_cast(1))); + this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride), + c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(), + _alpha, (current.k0()==0 ? 
_beta : static_cast(1))); } } } - if(pretransposed) - { + if (pretransposed) { b_panel += (bblocks * strat.out_width * kern_k); - } - else - { + } else { _bm->release(current.index()); } } @@ -391,14 +335,15 @@ class GemmInterleaved : public GemmCommon public: GemmInterleaved(GemmInterleaved &) = delete; - GemmInterleaved &operator=(GemmInterleaved &) = delete; + GemmInterleaved & operator= (GemmInterleaved &) = delete; /* Constructor */ GemmInterleaved(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, - const Tr alpha, const Tr beta, const int maxthreads, const bool pretransposed) - : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads), _pretransposed(pretransposed) - { + const Tr alpha, const Tr beta, const int maxthreads, const bool pretransposed) : + _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), + _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), + _maxthreads(maxthreads), _nthreads(maxthreads), _pretransposed(pretransposed) { const unsigned int L1_size = ci->get_L1_cache_size(); const unsigned int L2_size = ci->get_L2_cache_size(); @@ -426,7 +371,8 @@ public: // x_block: Work out how many rows (of length k_block) will fit in the L2 // Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents. - _x_block = (((L2_size * 9) / 10) - (_k_block * sizeof(Toi) * (strategy::out_width + strategy::out_height))) / (sizeof(Toi) * _k_block); + _x_block = (((L2_size * 9) / 10) - (_k_block * sizeof(Toi) * (strategy::out_width + strategy::out_height))) / + (sizeof(Toi) * _k_block); // Needs to be (at least a single) multiple of the kernel output width. _x_block /= strategy::out_width; @@ -434,7 +380,7 @@ public: // And tune to the presented problem size. int num_x_blocks = iceildiv(N, _x_block); - _x_block = iceildiv(N, num_x_blocks); + _x_block = iceildiv(N, num_x_blocks); _x_block = iceildiv(_x_block, strategy::out_width); _x_block *= strategy::out_width; @@ -450,45 +396,36 @@ public: // out work in units of out_height. Factor batches into the window, but // not multi for now (as this would cause problems with the buffer // manager). - - unsigned int get_window_size() const override - { + unsigned int get_window_size() const override { // _Mround is a multiple of out_height by definition. return (_Mround / strategy::out_height) * _nbatches; } // set_nthreads: pass on to buffer manager to avoid it waiting for non-existant threads. - void set_nthreads(int nthreads) override - { - if(_bm) - { - _bm->set_nthreads(nthreads); + void set_nthreads(int nthreads) override { + _nthreads = std::min(nthreads, _maxthreads); + if (_bm) { + _bm->set_nthreads(_nthreads); } } // Execute - void execute(unsigned int start, unsigned int end, int threadid) override - { - if(_pretransposed) - { + void execute(unsigned int start, unsigned int end, int threadid) override { + if (_pretransposed) { execute_internal(start, end, threadid); - } - else - { + } else { execute_internal(start, end, threadid); } } // Interface implementation - working space - size_t get_working_size() const override - { + size_t get_working_size() const override { // In all cases, we need one A buffer plus a C buffer per thread. 
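Before the working-space sizing continues below, here is a worked example of the _x_block calculation a few lines above. The numbers are illustrative assumptions, not values taken from this patch: a 512 KB L2, a _k_block of 256, float operands and a 12x8 kernel.

// Worked example of the _x_block sizing (illustrative numbers only).
constexpr unsigned int L2_size   = 512 * 1024;
constexpr unsigned int k_block   = 256;
constexpr unsigned int type_size = sizeof(float);            // stands in for sizeof(Toi)
constexpr unsigned int out_width = 12, out_height = 8;

constexpr unsigned int x_block_raw =
    (((L2_size * 9) / 10) - (k_block * type_size * (out_width + out_height))) /
    (type_size * k_block);                                    // = 440 columns
constexpr unsigned int x_block = (x_block_raw / out_width) * out_width;  // = 432, a multiple of out_width

After this, the value is re-balanced against the actual N via iceildiv(), as shown in the constructor above.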
size_t size = get_a_working_size() + (get_c_working_size() * _maxthreads); // For pretransposed case, there is no working space needed for B. // Otherwise, we need a BufferManager. - if(!_pretransposed) - { + if (!_pretransposed) { size += BufferManager::get_storage_requirement(_maxthreads, get_b_working_size()); } @@ -497,33 +434,28 @@ public: return size; } - void set_working_space(void *working_space) override - { + void set_working_space(void *working_space) override { // Make sure everything ends up cache line aligned int8_t *working_space_bytes = reinterpret_cast(working_space); - intptr_t working_space_int = reinterpret_cast(working_space); + intptr_t working_space_int = reinterpret_cast(working_space); - size_t diff = 0; + size_t diff=0; - if(working_space_int & 0x3F) - { + if (working_space_int & 0x3F) { diff = 0x40 - (working_space_int & 0x3F); } working_space_bytes += diff; - if(_pretransposed) - { + if (_pretransposed) { // Pretransposed case: just set internal pointer to parameter value. _working_space = reinterpret_cast(working_space_bytes); - } - else - { + } else { // Otherwise, use the first part of the working space for the buffer manager. // It's legal to call this again so don't leak a buffer manager if it already existed. delete _bm; - _bm = new BufferManager(_maxthreads, get_b_working_size(), reinterpret_cast(working_space_bytes)); + _bm = new BufferManager(_nthreads, get_b_working_size(), reinterpret_cast(working_space_bytes)); working_space_bytes += BufferManager::get_storage_requirement(_maxthreads, get_b_working_size()); @@ -532,24 +464,20 @@ public: } // Interface implementation - pretransposed - bool B_is_pretransposed() const override - { + bool B_is_pretransposed() const override { return _pretransposed; } - bool B_pretranspose_required() const override - { - return _pretransposed && (_B_transposed == nullptr); + bool B_pretranspose_required() const override { + return _pretransposed && (_B_transposed==nullptr); } // TODO: this could almost certainly be considerably simpler. - size_t get_B_pretransposed_array_size() const override - { - size_t total = 0; + size_t get_B_pretransposed_array_size() const override { + size_t total=0; blockwalker current(*this); - do - { + do { /* Figure out the size of each block. */ size_t x_size = (current.xmax() - current.x0()); size_t k_size = (current.kmax() - current.k0()); @@ -562,20 +490,17 @@ public: k_size *= strategy::k_unroll; total += x_size * k_size * sizeof(Toi); - } - while(current.advance()); + } while (current.advance()); return total; } - void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override - { + void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { blockwalker current(*this); - Toi *buffer = reinterpret_cast(in_buffer); - _B_transposed = buffer; + Toi *buffer = reinterpret_cast(in_buffer); + _B_transposed = buffer; - do - { + do { /* Figure out the size of each block. 
*/ size_t x_size = (current.xmax() - current.x0()); size_t k_size = (current.kmax() - current.k0()); @@ -587,31 +512,25 @@ public: k_size = iceildiv(k_size, strategy::k_unroll); k_size *= strategy::k_unroll; - if(_trB ^ strategy::B_transpose) - { + if (_trB ^ strategy::B_transpose) { Transform( - buffer, B + (current.multi() * B_multi_stride), ldb, - current.x0(), current.xmax(), current.k0(), current.kmax()); - } - else - { + buffer, B + (current.multi() * B_multi_stride), ldb, + current.x0(), current.xmax(), current.k0(), current.kmax()); + } else { Transform( - buffer, B + (current.multi() * B_multi_stride), ldb, - current.x0(), current.xmax(), current.k0(), current.kmax()); + buffer, B + (current.multi() * B_multi_stride), ldb, + current.x0(), current.xmax(), current.k0(), current.kmax()); } buffer += (x_size * k_size); - } - while(current.advance()); + } while (current.advance()); } - void set_pretransposed_B_data(void *in_buffer) override - { + void set_pretransposed_B_data(void *in_buffer) override { _B_transposed = reinterpret_cast(in_buffer); } - ~GemmInterleaved() override - { + ~GemmInterleaved() override { delete _bm; } }; diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp index 695236bdc4..6fed645d82 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp @@ -34,8 +34,8 @@ #include "profiler.hpp" #endif -namespace arm_gemm -{ +namespace arm_gemm { + // Implementation of the GemmCommon abstract class. // // This is implementation is for native GEMM with no transposition. @@ -43,11 +43,10 @@ namespace arm_gemm // By default the source data is used in-place, but if type conversion is // needed we need to allocate working space (CURRENTLY NOT IMPLEMENTED). -template -class GemmNative : public GemmCommon -{ +template +class GemmNative : public GemmCommon { typedef typename strategy::operand_type Toi; - typedef typename strategy::result_type Tri; + typedef typename strategy::result_type Tri; const unsigned int _Msize; const unsigned int _Nsize; @@ -58,36 +57,34 @@ class GemmNative : public GemmCommon Tr _beta; - const CPUInfo *const _ci; + const CPUInfo * const _ci; - unsigned int k_block = 0; - unsigned int n_block = 0; + unsigned int k_block=0; + unsigned int n_block=0; public: GemmNative(GemmNative &) = delete; - GemmNative &operator=(GemmNative &) = delete; + GemmNative & operator= (GemmNative &) = delete; - GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmultis, const Tr beta) - : _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmultis(nmultis), _beta(beta), _ci(ci) - { + GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmultis, const Tr beta) : + _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmultis(nmultis), _beta(beta), _ci(ci) { /* For now don't do any blocking. TODO: figure out if we should. */ k_block = K; n_block = N; } // Window is number of out_height blocks - unsigned int get_window_size() const override - { + unsigned int get_window_size() const override { return iceildiv(_Msize, strategy::out_height) * _nbatches * _nmultis; } // Actually execute the GEMM. 
- void execute(unsigned int start, unsigned int end, int) override - { + void execute(unsigned int start, unsigned int end, int) override { #ifdef CYCLE_PROFILING profiler prof; #endif - strategy strat(_ci); + strategy strat(_ci); + const unsigned int window_per_batch = iceildiv(_Msize, strategy::out_height); const unsigned int window_per_multi = window_per_batch * _nbatches; @@ -103,27 +100,24 @@ public: static_assert(std::is_same::value, "gemm_native: Operand types must be the same."); static_assert(std::is_same::value, "gemm_native: Result types must be the same."); - for(unsigned int multi = first_multi; multi <= last_multi; multi++) - { + for (unsigned int multi=first_multi; multi<=last_multi; multi++) { const unsigned int batch_0 = (multi == first_multi) ? first_batch : 0; - const unsigned int batch_max = (multi == last_multi) ? last_batch : _nbatches - 1; + const unsigned int batch_max = (multi == last_multi) ? last_batch : (_nbatches-1); - for(unsigned int batch = batch_0; batch <= batch_max; batch++) - { - const unsigned int m_start = ((multi == first_multi) && (batch == first_batch)) ? first_row : 0; - const unsigned int m_end = ((multi == last_multi) && (batch == last_batch)) ? last_row : _Msize; + for (unsigned int batch=batch_0; batch <= batch_max; batch++) { + const unsigned int m_start = ((multi == first_multi) && (batch==first_batch)) ? first_row : 0; + const unsigned int m_end = ((multi == last_multi) && (batch==last_batch)) ? last_row : _Msize; - for(unsigned int y0 = m_start; y0 < m_end; y0 += strategy::out_height) - { + for (unsigned int y0=m_start; y0_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda, this->_Bptr + (multi * this->_B_multi_stride), this->_ldb, this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc, - _beta, (ymax - y0), _Nsize, _Ksize); + _beta, (ymax-y0), _Nsize, _Ksize); } } } diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp index 8f1f377aaf..4e8b811e83 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp @@ -29,14 +29,13 @@ #include "kernels/a64_gemm_u16_12x8.hpp" -namespace arm_gemm -{ -template <> +namespace arm_gemm { + +template<> UniqueGemmCommon gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, uint32_t alpha, uint32_t beta, - const int maxthreads, const bool pretransposed_hint) -{ + const int maxthreads, const bool pretransposed_hint) { return UniqueGemmCommon(new GemmInterleaved(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); } diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp index 0e9f3f255c..321aa65d83 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp @@ -27,19 +27,17 @@ #include "gemm_common.hpp" #include "gemm_interleaved.hpp" -#include "kernels/a64_gemm_u8_12x8.hpp" #include "kernels/a64_gemm_u8_4x4.hpp" +#include "kernels/a64_gemm_u8_12x8.hpp" + +namespace arm_gemm { -namespace arm_gemm -{ -template <> +template<> UniqueGemmCommon gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, const uint32_t alpha, 
const uint32_t beta, - const int maxthreads, const bool pretransposed_hint) -{ - if(ci.has_dotprod()) - { + const int maxthreads, const bool pretransposed_hint) { + if (ci.has_dotprod()) { // Dot product supporting CPUs. This family has a special version for A55r1. return UniqueGemmCommon(new GemmInterleaved(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); } diff --git a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp index bb09770efc..d91b44b9a8 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp @@ -25,84 +25,70 @@ #include "arm_gemm.hpp" -namespace arm_gemm -{ +namespace arm_gemm { /* "Batched GEMV" (where M=1 and nbatches>1) can be executed much more * efficiently as a GEMM (with M'=nbatches and nbatches'=1). This wrapper * implements this. */ -template -class GemvBatched : public GemmCommon -{ +template +class GemvBatched : public GemmCommon { private: UniqueGemmCommon _subgemm = nullptr; public: GemvBatched(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, - const To alpha, const To beta, const int maxthreads, const bool pretransposed_hint) - { + const To alpha, const To beta, const int maxthreads, const bool pretransposed_hint) { /* Just create a subgemm with batches->M */ - _subgemm = gemm(ci, nbatches, N, K, 1, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint); + _subgemm = gemm(ci, nbatches, N, K, 1, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint); } void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride, const To *B, const int ldb, const int B_multi_stride, - Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) override - { + Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) override { /* A and C's batch stride becomes their new row stride. New batch stride is 0 as nbatches for subgemm is always 1. 
*/ _subgemm->set_arrays(A, A_batch_stride, 0, A_multi_stride, B, ldb, B_multi_stride, C, C_batch_stride, 0, C_multi_stride); } - unsigned int get_window_size() const override - { + unsigned int get_window_size() const override { return _subgemm->get_window_size(); } - void set_nthreads(int nthreads) override - { + void set_nthreads(int nthreads) override { _subgemm->set_nthreads(nthreads); } - void execute(unsigned int start, unsigned int end, int threadid) override - { + void execute(unsigned int start, unsigned int end, int threadid) override { _subgemm->execute(start, end, threadid); } - size_t get_working_size() const override - { + size_t get_working_size() const override { return _subgemm->get_working_size(); } - void set_working_space(void *space) override - { + void set_working_space(void *space) override { _subgemm->set_working_space(space); } - bool B_is_pretransposed() const override - { + bool B_is_pretransposed() const override { return _subgemm->B_is_pretransposed(); } - bool B_pretranspose_required() const override - { + bool B_pretranspose_required() const override { return _subgemm->B_pretranspose_required(); } - size_t get_B_pretransposed_array_size() const override - { + size_t get_B_pretransposed_array_size() const override { return _subgemm->get_B_pretransposed_array_size(); } - void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override - { + void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override { _subgemm->pretranspose_B_array(buffer, B, ldb, B_multi_stride); } - void set_pretransposed_B_data(void *buffer) override - { + void set_pretransposed_B_data(void *buffer) override { _subgemm->set_pretransposed_B_data(buffer); } }; diff --git a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp index e5cc79eaed..241c5fea27 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp @@ -34,8 +34,8 @@ #include "profiler.hpp" #endif -namespace arm_gemm -{ +namespace arm_gemm { + // Implementation of the GemmCommon abstract class. // // This is implementation is for a "native" (no-transform) GEMV with a @@ -43,53 +43,48 @@ namespace arm_gemm // // As a native operation the source data is used in-place, so the internal // and external operand/result types must match. 
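Stepping back to the GemvBatched wrapper defined just above (before the GemvNativeTransposed header that follows): the whole trick is in its set_arrays() override, which re-labels the batch dimension of the GEMV as the row dimension of the sub-GEMM. As a worked example with made-up sizes, take M=1, K=64, nbatches=8 and A_batch_stride=64. The element the batched view addresses as A[batch * A_batch_stride + k] is exactly the element the sub-GEMM view addresses as A[row * lda + k] once lda is set to A_batch_stride and the old batch index is read as the row index. So the eight length-64 input vectors are reinterpreted, with no data movement, as an 8x64 matrix: the sub-GEMM is created with M'=8 and nbatches'=1, is handed lda'=A_batch_stride and a new A batch stride of 0, and C_batch_stride likewise becomes ldc'.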
-template -class GemvNativeTransposed : public GemmCommon -{ +template +class GemvNativeTransposed : public GemmCommon { typedef typename strategy::operand_type Toi; - typedef typename strategy::result_type Tri; + typedef typename strategy::result_type Tri; const unsigned int _Nsize; const unsigned int _Ksize; + const unsigned int _nmultis; const Tr _beta; - const CPUInfo *const _ci; + const CPUInfo * const _ci; - unsigned int m_block = 0; - unsigned int n_block = 0; + unsigned int m_block=0; + unsigned int n_block=0; public: GemvNativeTransposed(GemvNativeTransposed &) = delete; - GemvNativeTransposed &operator=(GemvNativeTransposed &) = delete; + GemvNativeTransposed & operator= (GemvNativeTransposed &) = delete; - GemvNativeTransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const Tr beta) - : _Nsize(N), _Ksize(K), _nmultis(nmultis), _beta(beta), _ci(ci) - { + GemvNativeTransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const Tr beta) : _Nsize(N), _Ksize(K), _nmultis(nmultis), _beta(beta), _ci(ci) { /* For now don't do any blocking. TODO: figure out if we should. */ m_block = K; n_block = N; } // Window is number of out_width blocks times number of multis. - unsigned int get_window_size() const override - { + unsigned int get_window_size() const override { return iceildiv(_Nsize, strategy::out_width) * _nmultis; } // Actually execute the GEMV. - void execute(unsigned int start, unsigned int end, int) override - { + void execute(unsigned int start, unsigned int end, int) override { #ifdef CYCLE_PROFILING profiler prof; #endif - strategy strat(_ci); const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width); - const unsigned int multi_0 = start / window_per_multi; - const unsigned int multi_end = end / window_per_multi; + const unsigned int multi_0 = start / window_per_multi; + const unsigned int multi_end = end / window_per_multi; const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width; const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width; @@ -97,27 +92,25 @@ public: static_assert(std::is_same::value, "gemv_transposed: Operand types must be the same."); static_assert(std::is_same::value, "gemv_transposed: Result types must be the same."); - for(unsigned int multi = multi_0; multi <= multi_end; multi++) - { - const unsigned int n_start = (multi == multi_0) ? n_0 : 0; - const unsigned int n_end = (multi == multi_end) ? n_max : _Nsize; + for (unsigned int multi=multi_0; multi<=multi_end; multi++) { + const unsigned int n_start = (multi==multi_0) ? n_0 : 0; + const unsigned int n_end = (multi==multi_end) ? 
n_max : _Nsize; - if(n_end <= n_start) + if (n_end <= n_start) continue; - for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block) - { + for (unsigned int m0=0; m0<_Ksize; m0+=m_block) { unsigned int mmax = std::min(m0 + m_block, _Ksize); - for(unsigned int n0 = n_start; n0 < n_end; n0 += n_block) - { + + for (unsigned int n0=n_start; n0_Bptr + (multi * this->_B_multi_stride) + (m0 * this->_ldb) + n0, this->_Aptr + (multi * this->_A_multi_stride) + m0, this->_Cptr + (multi * this->_C_multi_stride) + n0, - _beta, this->_ldb, (mmax - m0), (nmax - n0)); + _beta, this->_ldb, (mmax-m0), (nmax-n0)); } } } diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp index 770ee033c8..e53ddb26c1 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp @@ -34,66 +34,64 @@ #include "profiler.hpp" #endif -namespace arm_gemm -{ +namespace arm_gemm { + // Implementation of the GemmCommon abstract class. // // This is implementation is for GEMV with pretransposition. +// // batches are not supported as a batched GEMV makes no sense (can be converted to a GEMM). - -template -class GemvPretransposed : public GemmCommon -{ +template +class GemvPretransposed : public GemmCommon { typedef typename strategy::operand_type Toi; - typedef typename strategy::result_type Tri; + typedef typename strategy::result_type Tri; const unsigned int _Nsize; const unsigned int _Ksize; + const unsigned int _nmultis; const bool _trB; const Tr _beta; - const CPUInfo *const _ci; - const unsigned int _buffer_per_multi; + const CPUInfo * const _ci; - unsigned int m_block = 0; - unsigned int n_block = 0; + const unsigned int _buffer_per_multi; + + unsigned int m_block=0; + unsigned int n_block=0; const Toi *_A_pretransposed = nullptr; public: GemvPretransposed(GemvPretransposed &) = delete; - GemvPretransposed &operator=(GemvPretransposed &) = delete; + GemvPretransposed & operator= (GemvPretransposed &) = delete; - GemvPretransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const bool trB, const Tr beta) - : _Nsize(N), _Ksize(K), _nmultis(nmultis), _trB(trB), _beta(beta), _ci(ci), _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave) * strategy::A_interleave) - { + GemvPretransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const bool trB, const Tr beta) : + _Nsize(N), _Ksize(K), _nmultis(nmultis), _trB(trB), _beta(beta), _ci(ci), + _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave) * strategy::A_interleave) { /* For now don't do any blocking. TODO: figure out if we should. */ m_block = K; n_block = N; } // Window is number of out_width blocks, times number of multis. - unsigned int get_window_size() const override - { + unsigned int get_window_size() const override { return iceildiv(_Nsize, strategy::out_width) * _nmultis; } // Actually execute the GEMV. - void execute(unsigned int start, unsigned int end, int) override - { + void execute(unsigned int start, unsigned int end, int) override { #ifdef CYCLE_PROFILING profiler prof; #endif - strategy strat(_ci); /* Break the window values down into multis of interest... 
*/ const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width); - const unsigned int multi_0 = start / window_per_multi; - const unsigned int multi_end = end / window_per_multi; + const unsigned int multi_0 = start / window_per_multi; + const unsigned int multi_end = end / window_per_multi; /* ... and figure out where we start and end in the first and last multi. */ const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width; @@ -101,66 +99,56 @@ public: static_assert(std::is_same::value, "GemvPretransposed: Result types must be the same."); - for(unsigned int multi = multi_0; multi <= multi_end; multi++) - { - const unsigned int n_start = (multi == multi_0) ? n_0 : 0; - const unsigned int n_end = (multi == multi_end) ? n_max : _Nsize; + for (unsigned int multi=multi_0; multi<=multi_end; multi++) { + const unsigned int n_start = (multi==multi_0) ? n_0 : 0; + const unsigned int n_end = (multi==multi_end) ? n_max : _Nsize; - if(n_end <= n_start) + if (n_end <= n_start) continue; - for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block) - { + for (unsigned int m0=0; m0<_Ksize; m0+=m_block) { unsigned int mmax = std::min(m0 + m_block, _Ksize); - for(unsigned int n = n_start; n < n_end; n += n_block) - { + + for (unsigned int n=n_start; n_Bptr below instead */ strat.kernel(_A_pretransposed + (multi * _buffer_per_multi) + (n * _Ksize) + (m0 * strategy::A_interleave), (_Ksize * strategy::A_interleave), this->_Aptr + (multi * this->_A_multi_stride) + m0, this->_Cptr + (multi * this->_C_multi_stride) + n, - _beta, (mmax - m0), (nmax - n)); + _beta, (mmax-m0), (nmax-n)); } } } } /* Pretransposed interface implementation */ - bool B_is_pretransposed() const override - { + bool B_is_pretransposed() const override { return true; } - bool B_pretranspose_required() const override - { + bool B_pretranspose_required() const override { /* Transpose is required if _A_pretransposed is still nullptr */ return (_A_pretransposed == nullptr); } - size_t get_B_pretransposed_array_size() const override - { + size_t get_B_pretransposed_array_size() const override { return _buffer_per_multi * _nmultis * sizeof(To); } - void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override - { + void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override { Toi *A_buffer = reinterpret_cast(buffer); - for(unsigned int multi = 0; multi < _nmultis; multi++) - { + for (unsigned int multi=0; multi<_nmultis; multi++) { /* Reverse sense here as we are dealing with B rather than A. So if * strategy::A_transpose is false and _trB is false, we still * transpose. 
*/ - if(_trB ^ strategy::A_transpose) - { + if (_trB ^ strategy::A_transpose) { Transform(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize); - } - else - { + } else { Transform(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize); } } @@ -168,8 +156,7 @@ public: _A_pretransposed = A_buffer; } - void set_pretransposed_B_data(void *buffer) override - { + void set_pretransposed_B_data(void *buffer) override { _A_pretransposed = reinterpret_cast(buffer); } }; diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp index de11dc582c..01bf1f9297 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp @@ -25,8 +25,8 @@ #ifdef __arm__ -namespace arm_gemm -{ +namespace arm_gemm { + // Actual kernel implementations void a32_sgemm_8x6(const float *, const float *, float *, int, int, int); void a32_sgemm_8x6_a53(const float *, const float *, float *, int, int, int); @@ -40,8 +40,7 @@ void a32_sgemm_8x6_a55r1(const float *, const float *, float *, int, int, int); // All kernels in the family must share these characteristics. The actual // kernel to be used can be chosen at runtime, based on the CPU_type // structure. -class sgemm_8x6 -{ +class sgemm_8x6 { public: typedef float operand_type; typedef float result_type; @@ -50,25 +49,23 @@ public: /* Describes the data layout for A input */ static const int A_interleave = 6; - static const int A_block = 1; - static const int A_transpose = 0; + static const int A_block = 1; + static const int A_transpose = 0; /* Same for B input */ static const int B_interleave = 8; - static const int B_block = 1; - static const int B_transpose = 1; + static const int B_block = 1; + static const int B_transpose = 1; /* Kernel blocking parameters */ - static const int out_width = 8; + static const int out_width = 8; static const int out_height = 6; - static const int k_unroll = 1; + static const int k_unroll = 1; kern_type kernel = a32_sgemm_8x6; - sgemm_8x6(const CPUInfo *ci) - { - switch(ci->get_cpu_model()) - { + sgemm_8x6(const CPUInfo *ci) { + switch(ci->get_cpu_model()) { case CPUModel::A53: kernel = a32_sgemm_8x6_a53; break; @@ -78,7 +75,7 @@ public: break; default: - kernel = a32_sgemm_8x6; + /* Generic kernel is selected by default. */ break; } } diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp index 428498f79e..e3844d8825 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp @@ -37,360 +37,370 @@ // Note that the intent of this is that either ablocks or bblocks will be 1 // - this construction allows the output loop to proceed in either order. 
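As a worked example of the window decomposition used by the GEMV execute() methods in the two files above: the arithmetic below is copied from those methods, the sample values (out_width = 32, Nsize = 100, nmultis = 2, work range [3, 6)) are illustrative only, and iceildiv is assumed to round up as elsewhere in arm_gemm.

// Window -> (multi, column) mapping: window_per_multi = iceildiv(100, 32) = 4,
// so get_window_size() = 4 * 2 = 8.  The sub-range [3, 6) then gives:
static void gemv_window_example() {
    const unsigned int Nsize = 100, out_width = 32, start = 3, end = 6;
    const unsigned int window_per_multi = (Nsize + out_width - 1) / out_width;        // 4
    const unsigned int multi_0   = start / window_per_multi;                          // 0
    const unsigned int multi_end = end   / window_per_multi;                          // 1
    const unsigned int n_0   = (start - (multi_0   * window_per_multi)) * out_width;  // 96
    const unsigned int n_max = (end   - (multi_end * window_per_multi)) * out_width;  // 64
    // multi 0 therefore covers columns [96, Nsize) and multi 1 covers columns [0, 64).
    (void)n_0; (void)n_max;
}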
-namespace arm_gemm -{ -void a32_sgemm_8x6_a53(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) -{ +namespace arm_gemm { + +void a32_sgemm_8x6_a53(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) { const float *a_ptr = Apanel; - float *c_ptr = Cpanel; + float *c_ptr = Cpanel; - for(int yb = 0; yb < ablocks; yb++) - { + for (int yb=0; ybget_cpu_model() == CPUModel::A55r1) - { + gemm_s8_12x8(const CPUInfo *ci) { + if (ci->get_cpu_model() == CPUModel::A55r1) { kernel = a64_gemm_s8_12x8_a55r1; } } diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/a55r1.cpp index ef2f29183c..eaa7979a31 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/a55r1.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/a55r1.cpp @@ -31,40 +31,37 @@ #include "dot_toolchain_support.h" #endif -namespace arm_gemm -{ -void a64_gemm_s8_12x8_a55r1(const int8_t *Apanel, const int8_t *Bpanel, int32_t *Cpanel, const int ablocks, const int bblocks, const int K) -{ +namespace arm_gemm { + +void a64_gemm_s8_12x8_a55r1(const int8_t *Apanel, const int8_t *Bpanel, int32_t *Cpanel, const int ablocks, const int bblocks, const int K) { const int8_t *a_ptr = Apanel; - int32_t *c_ptr = Cpanel; + int32_t *c_ptr = Cpanel; // We divide K by 4 because the sdot instruction processes 4 elements at a time. - const int W = K / 4; + const int W = K/4; // Fix up for odd lengths - set a flag if K is odd, but make // sure we round up the iteration count. - const int oddk = (W & 1); - const int k_iters = ((W + 1) / 2) - 1; + const int oddk = (W & 1); + const int k_iters = ((W+1)/2) - 1; - for(int yb = 0; yb < ablocks; yb++) - { + for (int yb=0; ybget_cpu_model() == CPUModel::A55r1) - { + gemm_u8_12x8(const CPUInfo *ci) { + if (ci->get_cpu_model() == CPUModel::A55r1) { kernel = a64_gemm_u8_12x8_a55r1; } } diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/a55r1.cpp index f8fafbdf84..994aea65f7 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/a55r1.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/a55r1.cpp @@ -31,40 +31,37 @@ #include "dot_toolchain_support.h" #endif -namespace arm_gemm -{ -void a64_gemm_u8_12x8_a55r1(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, const int ablocks, const int bblocks, const int K) -{ +namespace arm_gemm { + +void a64_gemm_u8_12x8_a55r1(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, const int ablocks, const int bblocks, const int K) { const uint8_t *a_ptr = Apanel; - uint32_t *c_ptr = Cpanel; + uint32_t *c_ptr = Cpanel; // We divide K by 4 because the udot instruction processes 4 elements at a time. - const int W = K / 4; + const int W = K/4; // Fix up for odd lengths - set a flag if K is odd, but make // sure we round up the iteration count. 
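The K-loop bookkeeping shared by the dot-product kernels can be checked with a small example; the arithmetic is taken from the code above and below, only the sample value of K is made up.

// Each sdot/udot processes 4 K-values and the main loop consumes two such
// groups per iteration, with the last pair (or odd single group) handled by
// a detached final iteration.  For K = 20:
static void dot_kernel_k_example() {
    const int K = 20;
    const int W = K / 4;                    // 5 dot-product groups
    const int oddk = (W & 1);               // 1: an unpaired group remains
    const int k_iters = ((W + 1) / 2) - 1;  // 2 in-loop iterations before the detached one
    (void)oddk; (void)k_iters;
}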
- const int oddk = (W & 1); - const int k_iters = ((W + 1) / 2) - 1; + const int oddk = (W & 1); + const int k_iters = ((W+1)/2) - 1; - for(int yb = 0; yb < ablocks; yb++) - { + for (int yb=0; ybget_cpu_model() == CPUModel::A55r1) - { + hgemm_24x8(const CPUInfo *ci) { + if (ci->get_cpu_model() == CPUModel::A55r1) { kernel = a64_hgemm_asimd_24x8_a55r1; } } diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp index 2186117536..a3839ce07b 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp @@ -39,25 +39,22 @@ // Note that the intent of this is that either ablocks or bblocks will be 1 // - this construction allows the output loop to proceed in either order. -namespace arm_gemm -{ -void a64_hgemm_asimd_24x8_a55r1(const __fp16 *Apanel, const __fp16 *Bpanel, __fp16 *Cpanel, int ablocks, int bblocks, int K) -{ +namespace arm_gemm { + +void a64_hgemm_asimd_24x8_a55r1(const __fp16 *Apanel, const __fp16 *Bpanel, __fp16 *Cpanel, int ablocks, int bblocks, int K) { const __fp16 *a_ptr = Apanel; - __fp16 *c_ptr = Cpanel; + __fp16 *c_ptr = Cpanel; // Fix up for odd lengths - set a flag if K is odd, but make // sure we round up the iteration count. - int oddk = (K & 1); - int k_iters = ((K + 1) / 2) - 1; + int oddk = (K & 1); + int k_iters = ((K+1)/2) - 1; - for(int yb = 0; yb < ablocks; yb++) - { + for (int yb=0; ybget_cpu_model()) - { + switch(ci->get_cpu_model()) { case CPUModel::A53: kernel = a64_sgemm_asimd_12x8_a53; break; diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a53.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a53.cpp index 618ebc733c..24001915c5 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a53.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a53.cpp @@ -27,333 +27,347 @@ #include "../../asmlib.hpp" -namespace arm_gemm -{ -void a64_sgemm_asimd_12x8_a53(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) -{ +namespace arm_gemm { + +void a64_sgemm_asimd_12x8_a53(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) { const float *a_ptr = Apanel; - float *c_ptr = Cpanel; + float *c_ptr = Cpanel; - for(int yb = 0; yb < ablocks; yb++) - { + for (int yb=0; yb -namespace arm_gemm -{ -void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, float *C, int ldc, float beta, int M, int N, int K) -{ - const int oddk = ((K % 8) >= 4) ? 1 : 0; - const int beta0 = (beta == 0.0f) ? 1 : 0; +namespace arm_gemm { + +void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, float *C, int ldc, float beta, int M, int N, int K) { + const int oddk = ((K % 8) >= 4) ? 1 : 0; + const int beta0 = (beta == 0.0f) ? 
1 : 0; const int oddones = (K % 4); float dummy_buffer[16]; @@ -67,12 +66,12 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo const float *b_ptr = B + x0; - int loops = ((K + 4) / 8) - 1; - int odds = oddones; + int loops = ((K+4)/8) - 1; + int odds = oddones; size_t ldbb = ldb * sizeof(float); - __asm __volatile( + __asm __volatile ( "a0 .req v0\n" "a1 .req v1\n" "a2 .req v2\n" @@ -107,140 +106,140 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo "b2aq .req q14\n" "b3aq .req q15\n" - "movi v16.4s, #0x0\n" - "ldr a0q, [%[a_ptr0]]\n" - "movi v17.4s, #0x0\n" - "ldr b0q, [%[b_ptr]]\n" - "movi v18.4s, #0x0\n" - "ldr b1q, [%[b_ptr], #16]\n" - "movi v19.4s, #0x0\n" - "ldr b2q, [%[b_ptr], #32]\n" - "movi v20.4s, #0x0\n" - "ldr b3q, [%[b_ptr], #48]\n" - "movi v21.4s, #0x0\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "ldr a1q, [%[a_ptr1]]\n" - "movi v22.4s, #0x0\n" - "ldr a2q, [%[a_ptr2]]\n" - "movi v23.4s, #0x0\n" - "ldr a3q, [%[a_ptr3]]\n" - "movi v24.4s, #0x0\n" - "ldr b0aq, [%[b_ptr]]\n" - "movi v25.4s, #0x0\n" - "ldr b1aq, [%[b_ptr], #16]\n" - "movi v26.4s, #0x0\n" - "ldr b2aq, [%[b_ptr], #32]\n" - "cbz %w[beta0], 5f\n" - "movi v27.4s, #0x0\n" - "movi v28.4s, #0x0\n" - "movi v29.4s, #0x0\n" - "movi v30.4s, #0x0\n" - "movi v31.4s, #0x0\n" + "movi v16.4s, #0x0\n" + "ldr a0q, [%[a_ptr0]]\n" + "movi v17.4s, #0x0\n" + "ldr b0q, [%[b_ptr]]\n" + "movi v18.4s, #0x0\n" + "ldr b1q, [%[b_ptr], #16]\n" + "movi v19.4s, #0x0\n" + "ldr b2q, [%[b_ptr], #32]\n" + "movi v20.4s, #0x0\n" + "ldr b3q, [%[b_ptr], #48]\n" + "movi v21.4s, #0x0\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "ldr a1q, [%[a_ptr1]]\n" + "movi v22.4s, #0x0\n" + "ldr a2q, [%[a_ptr2]]\n" + "movi v23.4s, #0x0\n" + "ldr a3q, [%[a_ptr3]]\n" + "movi v24.4s, #0x0\n" + "ldr b0aq, [%[b_ptr]]\n" + "movi v25.4s, #0x0\n" + "ldr b1aq, [%[b_ptr], #16]\n" + "movi v26.4s, #0x0\n" + "ldr b2aq, [%[b_ptr], #32]\n" + "cbz %w[beta0], 5f\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" // Skip if no complete loops. 
- "cbz %w[loops], 4f\n" - "b 1f\n" + "cbz %w[loops], 4f\n" + "b 1f\n" // If beta is non-zero, need to load and multiply by beta "5:\n" - "ld1r {v4.4s}, [%[betaptr]]\n" - "ldr q16, [%[c_ptr0]]\n" - "ldr q17, [%[c_ptr0], #16]\n" - "ldr q18, [%[c_ptr0], #32]\n" - "ldr q19, [%[c_ptr0], #48]\n" - - "ldr q20, [%[c_ptr1]]\n" - "fmul v16.4s, v16.4s, v4.4s\n" - "ldr q21, [%[c_ptr1], #16]\n" - "fmul v17.4s, v17.4s, v4.4s\n" - "ldr q22, [%[c_ptr1], #32]\n" - "fmul v18.4s, v18.4s, v4.4s\n" - "ldr q23, [%[c_ptr1], #48]\n" - "fmul v19.4s, v19.4s, v4.4s\n" - - "ldr q24, [%[c_ptr2]]\n" - "fmul v20.4s, v20.4s, v4.4s\n" - "ldr q25, [%[c_ptr2], #16]\n" - "fmul v21.4s, v21.4s, v4.4s\n" - "ldr q26, [%[c_ptr2], #32]\n" - "fmul v22.4s, v22.4s, v4.4s\n" - "ldr q27, [%[c_ptr2], #48]\n" - "fmul v23.4s, v23.4s, v4.4s\n" - - "ldr q28, [%[c_ptr3]]\n" - "fmul v24.4s, v24.4s, v4.4s\n" - "ldr q29, [%[c_ptr3], #16]\n" - "fmul v25.4s, v25.4s, v4.4s\n" - "ldr q30, [%[c_ptr3], #32]\n" - "fmul v26.4s, v26.4s, v4.4s\n" - "ldr q31, [%[c_ptr3], #48]\n" - "fmul v27.4s, v27.4s, v4.4s\n" - - "fmul v28.4s, v28.4s, v4.4s\n" - "fmul v29.4s, v29.4s, v4.4s\n" - "fmul v30.4s, v30.4s, v4.4s\n" - "fmul v31.4s, v31.4s, v4.4s\n" - - "cbz %w[loops], 4f\n" + "ld1r {v4.4s}, [%[betaptr]]\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #16]\n" + "ldr q18, [%[c_ptr0], #32]\n" + "ldr q19, [%[c_ptr0], #48]\n" + + "ldr q20, [%[c_ptr1]]\n" + "fmul v16.4s, v16.4s, v4.4s\n" + "ldr q21, [%[c_ptr1], #16]\n" + "fmul v17.4s, v17.4s, v4.4s\n" + "ldr q22, [%[c_ptr1], #32]\n" + "fmul v18.4s, v18.4s, v4.4s\n" + "ldr q23, [%[c_ptr1], #48]\n" + "fmul v19.4s, v19.4s, v4.4s\n" + + "ldr q24, [%[c_ptr2]]\n" + "fmul v20.4s, v20.4s, v4.4s\n" + "ldr q25, [%[c_ptr2], #16]\n" + "fmul v21.4s, v21.4s, v4.4s\n" + "ldr q26, [%[c_ptr2], #32]\n" + "fmul v22.4s, v22.4s, v4.4s\n" + "ldr q27, [%[c_ptr2], #48]\n" + "fmul v23.4s, v23.4s, v4.4s\n" + + "ldr q28, [%[c_ptr3]]\n" + "fmul v24.4s, v24.4s, v4.4s\n" + "ldr q29, [%[c_ptr3], #16]\n" + "fmul v25.4s, v25.4s, v4.4s\n" + "ldr q30, [%[c_ptr3], #32]\n" + "fmul v26.4s, v26.4s, v4.4s\n" + "ldr q31, [%[c_ptr3], #48]\n" + "fmul v27.4s, v27.4s, v4.4s\n" + + "fmul v28.4s, v28.4s, v4.4s\n" + "fmul v29.4s, v29.4s, v4.4s\n" + "fmul v30.4s, v30.4s, v4.4s\n" + "fmul v31.4s, v31.4s, v4.4s\n" + + "cbz %w[loops], 4f\n" "1:\n" // Unroll 0 - "fmla v16.4s, bb0.4s, a0.s[0]\n" - "fmla v20.4s, bb0.4s, a1.s[0]\n" - "ldr b3aq, [%[b_ptr], #48]\n" - "fmla v24.4s, bb0.4s, a2.s[0]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v28.4s, bb0.4s, a3.s[0]\n" - "ldr b0q, [%[b_ptr]]\n" - - "fmla v17.4s, bb1.4s, a0.s[0]\n" - "fmla v21.4s, bb1.4s, a1.s[0]\n" - "ldr a0aq, [%[a_ptr0], #16]\n" - "fmla v25.4s, bb1.4s, a2.s[0]\n" - "fmla v29.4s, bb1.4s, a3.s[0]\n" - "ldr b1q, [%[b_ptr], #16]\n" - - "fmla v18.4s, bb2.4s, a0.s[0]\n" - "fmla v22.4s, bb2.4s, a1.s[0]\n" - "ldr a1aq, [%[a_ptr1], #16]\n" - "fmla v26.4s, bb2.4s, a2.s[0]\n" - "fmla v30.4s, bb2.4s, a3.s[0]\n" - "ldr b2q, [%[b_ptr], #32]\n" - - "fmla v19.4s, bb3.4s, a0.s[0]\n" - "fmla v23.4s, bb3.4s, a1.s[0]\n" - "ldr a2aq, [%[a_ptr2], #16]\n" - "fmla v27.4s, bb3.4s, a2.s[0]\n" - "fmla v31.4s, bb3.4s, a3.s[0]\n" - "ldr b3q, [%[b_ptr], #48]\n" + "fmla v16.4s, bb0.4s, a0.s[0]\n" + "fmla v20.4s, bb0.4s, a1.s[0]\n" + "ldr b3aq, [%[b_ptr], #48]\n" + "fmla v24.4s, bb0.4s, a2.s[0]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v28.4s, bb0.4s, a3.s[0]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0.s[0]\n" + "fmla v21.4s, bb1.4s, a1.s[0]\n" + "ldr a0aq, [%[a_ptr0], #16]\n" + "fmla v25.4s, bb1.4s, 
a2.s[0]\n" + "fmla v29.4s, bb1.4s, a3.s[0]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0.s[0]\n" + "fmla v22.4s, bb2.4s, a1.s[0]\n" + "ldr a1aq, [%[a_ptr1], #16]\n" + "fmla v26.4s, bb2.4s, a2.s[0]\n" + "fmla v30.4s, bb2.4s, a3.s[0]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0.s[0]\n" + "fmla v23.4s, bb3.4s, a1.s[0]\n" + "ldr a2aq, [%[a_ptr2], #16]\n" + "fmla v27.4s, bb3.4s, a2.s[0]\n" + "fmla v31.4s, bb3.4s, a3.s[0]\n" + "ldr b3q, [%[b_ptr], #48]\n" // Unroll 1 - "fmla v16.4s, b0a.4s, a0.s[1]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v20.4s, b0a.4s, a1.s[1]\n" - "ldr a3aq, [%[a_ptr3], #16]\n" - "fmla v24.4s, b0a.4s, a2.s[1]\n" - "fmla v28.4s, b0a.4s, a3.s[1]\n" - "ldr b0aq, [%[b_ptr]]\n" - - "fmla v17.4s, b1a.4s, a0.s[1]\n" - "fmla v21.4s, b1a.4s, a1.s[1]\n" - "subs %w[loops], %w[loops], #1\n" - "fmla v25.4s, b1a.4s, a2.s[1]\n" - "fmla v29.4s, b1a.4s, a3.s[1]\n" - "ldr b1aq, [%[b_ptr], #16]\n" - - "fmla v18.4s, b2a.4s, a0.s[1]\n" - "fmla v22.4s, b2a.4s, a1.s[1]\n" - "fmla v26.4s, b2a.4s, a2.s[1]\n" - "fmla v30.4s, b2a.4s, a3.s[1]\n" - "ldr b2aq, [%[b_ptr], #32]\n" - - "fmla v19.4s, b3a.4s, a0.s[1]\n" - "fmla v23.4s, b3a.4s, a1.s[1]\n" - "fmla v27.4s, b3a.4s, a2.s[1]\n" - "fmla v31.4s, b3a.4s, a3.s[1]\n" - "ldr b3aq, [%[b_ptr], #48]\n" + "fmla v16.4s, b0a.4s, a0.s[1]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v20.4s, b0a.4s, a1.s[1]\n" + "ldr a3aq, [%[a_ptr3], #16]\n" + "fmla v24.4s, b0a.4s, a2.s[1]\n" + "fmla v28.4s, b0a.4s, a3.s[1]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0.s[1]\n" + "fmla v21.4s, b1a.4s, a1.s[1]\n" + "subs %w[loops], %w[loops], #1\n" + "fmla v25.4s, b1a.4s, a2.s[1]\n" + "fmla v29.4s, b1a.4s, a3.s[1]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0.s[1]\n" + "fmla v22.4s, b2a.4s, a1.s[1]\n" + "fmla v26.4s, b2a.4s, a2.s[1]\n" + "fmla v30.4s, b2a.4s, a3.s[1]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0.s[1]\n" + "fmla v23.4s, b3a.4s, a1.s[1]\n" + "fmla v27.4s, b3a.4s, a2.s[1]\n" + "fmla v31.4s, b3a.4s, a3.s[1]\n" + "ldr b3aq, [%[b_ptr], #48]\n" // Unroll 2 "fmla v16.4s, bb0.4s, a0.s[2]\n" @@ -273,173 +272,173 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo "ldr b3q, [%[b_ptr], #48]\n" // Unroll 3 - "fmla v16.4s, b0a.4s, a0.s[3]\n" - "fmla v20.4s, b0a.4s, a1.s[3]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v24.4s, b0a.4s, a2.s[3]\n" - "fmla v28.4s, b0a.4s, a3.s[3]\n" - "ldr b0aq, [%[b_ptr]]\n" - - "fmla v17.4s, b1a.4s, a0.s[3]\n" - "fmla v21.4s, b1a.4s, a1.s[3]\n" - "fmla v25.4s, b1a.4s, a2.s[3]\n" - "fmla v29.4s, b1a.4s, a3.s[3]\n" - "ldr b1aq, [%[b_ptr], #16]\n" - - "fmla v18.4s, b2a.4s, a0.s[3]\n" - "fmla v22.4s, b2a.4s, a1.s[3]\n" - "fmla v26.4s, b2a.4s, a2.s[3]\n" - "fmla v30.4s, b2a.4s, a3.s[3]\n" - "ldr b2aq, [%[b_ptr], #32]\n" - - "fmla v19.4s, b3a.4s, a0.s[3]\n" - "fmla v23.4s, b3a.4s, a1.s[3]\n" - "ldr a0q, [%[a_ptr0]]\n" - "fmla v27.4s, b3a.4s, a2.s[3]\n" - "fmla v31.4s, b3a.4s, a3.s[3]\n" - "ldr b3aq, [%[b_ptr], #48]\n" + "fmla v16.4s, b0a.4s, a0.s[3]\n" + "fmla v20.4s, b0a.4s, a1.s[3]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, b0a.4s, a2.s[3]\n" + "fmla v28.4s, b0a.4s, a3.s[3]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0.s[3]\n" + "fmla v21.4s, b1a.4s, a1.s[3]\n" + "fmla v25.4s, b1a.4s, a2.s[3]\n" + "fmla v29.4s, b1a.4s, a3.s[3]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0.s[3]\n" + "fmla v22.4s, b2a.4s, a1.s[3]\n" + "fmla v26.4s, b2a.4s, a2.s[3]\n" + "fmla v30.4s, b2a.4s, a3.s[3]\n" 
+ "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0.s[3]\n" + "fmla v23.4s, b3a.4s, a1.s[3]\n" + "ldr a0q, [%[a_ptr0]]\n" + "fmla v27.4s, b3a.4s, a2.s[3]\n" + "fmla v31.4s, b3a.4s, a3.s[3]\n" + "ldr b3aq, [%[b_ptr], #48]\n" // Unroll 4 - "fmla v16.4s, bb0.4s, a0a.s[0]\n" - "fmla v20.4s, bb0.4s, a1a.s[0]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v24.4s, bb0.4s, a2a.s[0]\n" - "fmla v28.4s, bb0.4s, a3a.s[0]\n" - "ldr b0q, [%[b_ptr]]\n" - - "fmla v17.4s, bb1.4s, a0a.s[0]\n" - "fmla v21.4s, bb1.4s, a1a.s[0]\n" - "ldr a1q, [%[a_ptr1]]\n" - "fmla v25.4s, bb1.4s, a2a.s[0]\n" - "fmla v29.4s, bb1.4s, a3a.s[0]\n" - "ldr b1q, [%[b_ptr], #16]\n" - - "fmla v18.4s, bb2.4s, a0a.s[0]\n" - "fmla v22.4s, bb2.4s, a1a.s[0]\n" - "ldr a2q, [%[a_ptr2]]\n" - "fmla v26.4s, bb2.4s, a2a.s[0]\n" - "fmla v30.4s, bb2.4s, a3a.s[0]\n" - "ldr b2q, [%[b_ptr], #32]\n" - - "fmla v19.4s, bb3.4s, a0a.s[0]\n" - "fmla v23.4s, bb3.4s, a1a.s[0]\n" - "ldr a3q, [%[a_ptr3]]\n" - "fmla v27.4s, bb3.4s, a2a.s[0]\n" - "fmla v31.4s, bb3.4s, a3a.s[0]\n" - "ldr b3q, [%[b_ptr], #48]\n" + "fmla v16.4s, bb0.4s, a0a.s[0]\n" + "fmla v20.4s, bb0.4s, a1a.s[0]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, bb0.4s, a2a.s[0]\n" + "fmla v28.4s, bb0.4s, a3a.s[0]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0a.s[0]\n" + "fmla v21.4s, bb1.4s, a1a.s[0]\n" + "ldr a1q, [%[a_ptr1]]\n" + "fmla v25.4s, bb1.4s, a2a.s[0]\n" + "fmla v29.4s, bb1.4s, a3a.s[0]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0a.s[0]\n" + "fmla v22.4s, bb2.4s, a1a.s[0]\n" + "ldr a2q, [%[a_ptr2]]\n" + "fmla v26.4s, bb2.4s, a2a.s[0]\n" + "fmla v30.4s, bb2.4s, a3a.s[0]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0a.s[0]\n" + "fmla v23.4s, bb3.4s, a1a.s[0]\n" + "ldr a3q, [%[a_ptr3]]\n" + "fmla v27.4s, bb3.4s, a2a.s[0]\n" + "fmla v31.4s, bb3.4s, a3a.s[0]\n" + "ldr b3q, [%[b_ptr], #48]\n" // Unroll 5 - "fmla v16.4s, b0a.4s, a0a.s[1]\n" - "fmla v20.4s, b0a.4s, a1a.s[1]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v24.4s, b0a.4s, a2a.s[1]\n" - "fmla v28.4s, b0a.4s, a3a.s[1]\n" - "ldr b0aq, [%[b_ptr]]\n" - - "fmla v17.4s, b1a.4s, a0a.s[1]\n" - "fmla v21.4s, b1a.4s, a1a.s[1]\n" - "fmla v25.4s, b1a.4s, a2a.s[1]\n" - "fmla v29.4s, b1a.4s, a3a.s[1]\n" - "ldr b1aq, [%[b_ptr], #16]\n" - - "fmla v18.4s, b2a.4s, a0a.s[1]\n" - "fmla v22.4s, b2a.4s, a1a.s[1]\n" - "fmla v26.4s, b2a.4s, a2a.s[1]\n" - "fmla v30.4s, b2a.4s, a3a.s[1]\n" - "ldr b2aq, [%[b_ptr], #32]\n" - - "fmla v19.4s, b3a.4s, a0a.s[1]\n" - "fmla v23.4s, b3a.4s, a1a.s[1]\n" - "fmla v27.4s, b3a.4s, a2a.s[1]\n" - "fmla v31.4s, b3a.4s, a3a.s[1]\n" - "ldr b3aq, [%[b_ptr], #48]\n" + "fmla v16.4s, b0a.4s, a0a.s[1]\n" + "fmla v20.4s, b0a.4s, a1a.s[1]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, b0a.4s, a2a.s[1]\n" + "fmla v28.4s, b0a.4s, a3a.s[1]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0a.s[1]\n" + "fmla v21.4s, b1a.4s, a1a.s[1]\n" + "fmla v25.4s, b1a.4s, a2a.s[1]\n" + "fmla v29.4s, b1a.4s, a3a.s[1]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0a.s[1]\n" + "fmla v22.4s, b2a.4s, a1a.s[1]\n" + "fmla v26.4s, b2a.4s, a2a.s[1]\n" + "fmla v30.4s, b2a.4s, a3a.s[1]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0a.s[1]\n" + "fmla v23.4s, b3a.4s, a1a.s[1]\n" + "fmla v27.4s, b3a.4s, a2a.s[1]\n" + "fmla v31.4s, b3a.4s, a3a.s[1]\n" + "ldr b3aq, [%[b_ptr], #48]\n" // Unroll 6 - "fmla v16.4s, bb0.4s, a0a.s[2]\n" - "fmla v20.4s, bb0.4s, a1a.s[2]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v24.4s, bb0.4s, a2a.s[2]\n" - "fmla v28.4s, 
bb0.4s, a3a.s[2]\n" - "ldr b0q, [%[b_ptr]]\n" - - "fmla v17.4s, bb1.4s, a0a.s[2]\n" - "fmla v21.4s, bb1.4s, a1a.s[2]\n" - "fmla v25.4s, bb1.4s, a2a.s[2]\n" - "fmla v29.4s, bb1.4s, a3a.s[2]\n" - "ldr b1q, [%[b_ptr], #16]\n" - - "fmla v18.4s, bb2.4s, a0a.s[2]\n" - "fmla v22.4s, bb2.4s, a1a.s[2]\n" - "fmla v26.4s, bb2.4s, a2a.s[2]\n" - "fmla v30.4s, bb2.4s, a3a.s[2]\n" - "ldr b2q, [%[b_ptr], #32]\n" - - "fmla v19.4s, bb3.4s, a0a.s[2]\n" - "fmla v23.4s, bb3.4s, a1a.s[2]\n" - "fmla v27.4s, bb3.4s, a2a.s[2]\n" - "fmla v31.4s, bb3.4s, a3a.s[2]\n" - "ldr b3q, [%[b_ptr], #48]\n" + "fmla v16.4s, bb0.4s, a0a.s[2]\n" + "fmla v20.4s, bb0.4s, a1a.s[2]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, bb0.4s, a2a.s[2]\n" + "fmla v28.4s, bb0.4s, a3a.s[2]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0a.s[2]\n" + "fmla v21.4s, bb1.4s, a1a.s[2]\n" + "fmla v25.4s, bb1.4s, a2a.s[2]\n" + "fmla v29.4s, bb1.4s, a3a.s[2]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0a.s[2]\n" + "fmla v22.4s, bb2.4s, a1a.s[2]\n" + "fmla v26.4s, bb2.4s, a2a.s[2]\n" + "fmla v30.4s, bb2.4s, a3a.s[2]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0a.s[2]\n" + "fmla v23.4s, bb3.4s, a1a.s[2]\n" + "fmla v27.4s, bb3.4s, a2a.s[2]\n" + "fmla v31.4s, bb3.4s, a3a.s[2]\n" + "ldr b3q, [%[b_ptr], #48]\n" // Unroll 7 - "fmla v16.4s, b0a.4s, a0a.s[3]\n" - "fmla v20.4s, b0a.4s, a1a.s[3]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v24.4s, b0a.4s, a2a.s[3]\n" - "fmla v28.4s, b0a.4s, a3a.s[3]\n" - "ldr b0aq, [%[b_ptr]]\n" - - "fmla v17.4s, b1a.4s, a0a.s[3]\n" - "fmla v21.4s, b1a.4s, a1a.s[3]\n" - "fmla v25.4s, b1a.4s, a2a.s[3]\n" - "fmla v29.4s, b1a.4s, a3a.s[3]\n" - "ldr b1aq, [%[b_ptr], #16]\n" - - "fmla v18.4s, b2a.4s, a0a.s[3]\n" - "fmla v22.4s, b2a.4s, a1a.s[3]\n" - "fmla v26.4s, b2a.4s, a2a.s[3]\n" - "fmla v30.4s, b2a.4s, a3a.s[3]\n" - "ldr b2aq, [%[b_ptr], #32]\n" - - "fmla v19.4s, b3a.4s, a0a.s[3]\n" - "fmla v23.4s, b3a.4s, a1a.s[3]\n" - "fmla v27.4s, b3a.4s, a2a.s[3]\n" - "fmla v31.4s, b3a.4s, a3a.s[3]\n" - "bne 1b\n" + "fmla v16.4s, b0a.4s, a0a.s[3]\n" + "fmla v20.4s, b0a.4s, a1a.s[3]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, b0a.4s, a2a.s[3]\n" + "fmla v28.4s, b0a.4s, a3a.s[3]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0a.s[3]\n" + "fmla v21.4s, b1a.4s, a1a.s[3]\n" + "fmla v25.4s, b1a.4s, a2a.s[3]\n" + "fmla v29.4s, b1a.4s, a3a.s[3]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0a.s[3]\n" + "fmla v22.4s, b2a.4s, a1a.s[3]\n" + "fmla v26.4s, b2a.4s, a2a.s[3]\n" + "fmla v30.4s, b2a.4s, a3a.s[3]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0a.s[3]\n" + "fmla v23.4s, b3a.4s, a1a.s[3]\n" + "fmla v27.4s, b3a.4s, a2a.s[3]\n" + "fmla v31.4s, b3a.4s, a3a.s[3]\n" + "bne 1b\n" // Skip to here "4:\n" // Detached final iteration // Unroll 0 - "fmla v16.4s, bb0.4s, a0.s[0]\n" - "fmla v20.4s, bb0.4s, a1.s[0]\n" - "ldr b3aq, [%[b_ptr], #48]\n" - "fmla v24.4s, bb0.4s, a2.s[0]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v28.4s, bb0.4s, a3.s[0]\n" - "ldr b0q, [%[b_ptr]]\n" - - "fmla v17.4s, bb1.4s, a0.s[0]\n" - "cbnz %w[oddk], 2f\n" // Deal with odd K before we load a0a - "fmla v21.4s, bb1.4s, a1.s[0]\n" - "ldr a0aq, [%[a_ptr0], #16]\n" - "fmla v25.4s, bb1.4s, a2.s[0]\n" - "fmla v29.4s, bb1.4s, a3.s[0]\n" - "ldr b1q, [%[b_ptr], #16]\n" - - "fmla v18.4s, bb2.4s, a0.s[0]\n" - "fmla v22.4s, bb2.4s, a1.s[0]\n" - "ldr a1aq, [%[a_ptr1], #16]\n" - "fmla v26.4s, bb2.4s, a2.s[0]\n" - "fmla v30.4s, bb2.4s, a3.s[0]\n" - "ldr b2q, [%[b_ptr], #32]\n" - - "fmla 
v19.4s, bb3.4s, a0.s[0]\n" - "fmla v23.4s, bb3.4s, a1.s[0]\n" - "ldr a2aq, [%[a_ptr2], #16]\n" - "fmla v27.4s, bb3.4s, a2.s[0]\n" - "fmla v31.4s, bb3.4s, a3.s[0]\n" - "ldr b3q, [%[b_ptr], #48]\n" + "fmla v16.4s, bb0.4s, a0.s[0]\n" + "fmla v20.4s, bb0.4s, a1.s[0]\n" + "ldr b3aq, [%[b_ptr], #48]\n" + "fmla v24.4s, bb0.4s, a2.s[0]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v28.4s, bb0.4s, a3.s[0]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0.s[0]\n" + "cbnz %w[oddk], 2f\n" // Deal with odd K before we load a0a + "fmla v21.4s, bb1.4s, a1.s[0]\n" + "ldr a0aq, [%[a_ptr0], #16]\n" + "fmla v25.4s, bb1.4s, a2.s[0]\n" + "fmla v29.4s, bb1.4s, a3.s[0]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0.s[0]\n" + "fmla v22.4s, bb2.4s, a1.s[0]\n" + "ldr a1aq, [%[a_ptr1], #16]\n" + "fmla v26.4s, bb2.4s, a2.s[0]\n" + "fmla v30.4s, bb2.4s, a3.s[0]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0.s[0]\n" + "fmla v23.4s, bb3.4s, a1.s[0]\n" + "ldr a2aq, [%[a_ptr2], #16]\n" + "fmla v27.4s, bb3.4s, a2.s[0]\n" + "fmla v31.4s, bb3.4s, a3.s[0]\n" + "ldr b3q, [%[b_ptr], #48]\n" // Unroll 1 "fmla v16.4s, b0a.4s, a0.s[1]\n" @@ -473,394 +472,394 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo "ldr b3aq, [%[b_ptr], #48]\n" // Unroll 2 - "fmla v16.4s, bb0.4s, a0.s[2]\n" - "fmla v20.4s, bb0.4s, a1.s[2]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v24.4s, bb0.4s, a2.s[2]\n" - "fmla v28.4s, bb0.4s, a3.s[2]\n" - "ldr b0q, [%[b_ptr]]\n" - - "fmla v17.4s, bb1.4s, a0.s[2]\n" - "fmla v21.4s, bb1.4s, a1.s[2]\n" - "fmla v25.4s, bb1.4s, a2.s[2]\n" - "fmla v29.4s, bb1.4s, a3.s[2]\n" - "ldr b1q, [%[b_ptr], #16]\n" - - "fmla v18.4s, bb2.4s, a0.s[2]\n" - "fmla v22.4s, bb2.4s, a1.s[2]\n" - "fmla v26.4s, bb2.4s, a2.s[2]\n" - "fmla v30.4s, bb2.4s, a3.s[2]\n" - "ldr b2q, [%[b_ptr], #32]\n" - - "fmla v19.4s, bb3.4s, a0.s[2]\n" - "fmla v23.4s, bb3.4s, a1.s[2]\n" - "fmla v27.4s, bb3.4s, a2.s[2]\n" - "fmla v31.4s, bb3.4s, a3.s[2]\n" - "ldr b3q, [%[b_ptr], #48]\n" + "fmla v16.4s, bb0.4s, a0.s[2]\n" + "fmla v20.4s, bb0.4s, a1.s[2]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, bb0.4s, a2.s[2]\n" + "fmla v28.4s, bb0.4s, a3.s[2]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0.s[2]\n" + "fmla v21.4s, bb1.4s, a1.s[2]\n" + "fmla v25.4s, bb1.4s, a2.s[2]\n" + "fmla v29.4s, bb1.4s, a3.s[2]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0.s[2]\n" + "fmla v22.4s, bb2.4s, a1.s[2]\n" + "fmla v26.4s, bb2.4s, a2.s[2]\n" + "fmla v30.4s, bb2.4s, a3.s[2]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0.s[2]\n" + "fmla v23.4s, bb3.4s, a1.s[2]\n" + "fmla v27.4s, bb3.4s, a2.s[2]\n" + "fmla v31.4s, bb3.4s, a3.s[2]\n" + "ldr b3q, [%[b_ptr], #48]\n" // Unroll 3 - "fmla v16.4s, b0a.4s, a0.s[3]\n" - "fmla v20.4s, b0a.4s, a1.s[3]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v24.4s, b0a.4s, a2.s[3]\n" - "fmla v28.4s, b0a.4s, a3.s[3]\n" - "ldr b0aq, [%[b_ptr]]\n" - - "fmla v17.4s, b1a.4s, a0.s[3]\n" - "fmla v21.4s, b1a.4s, a1.s[3]\n" - "fmla v25.4s, b1a.4s, a2.s[3]\n" - "fmla v29.4s, b1a.4s, a3.s[3]\n" - "ldr b1aq, [%[b_ptr], #16]\n" - - "fmla v18.4s, b2a.4s, a0.s[3]\n" - "fmla v22.4s, b2a.4s, a1.s[3]\n" - "fmla v26.4s, b2a.4s, a2.s[3]\n" - "fmla v30.4s, b2a.4s, a3.s[3]\n" - "ldr b2aq, [%[b_ptr], #32]\n" - - "fmla v19.4s, b3a.4s, a0.s[3]\n" - "fmla v23.4s, b3a.4s, a1.s[3]\n" - "fmla v27.4s, b3a.4s, a2.s[3]\n" - "fmla v31.4s, b3a.4s, a3.s[3]\n" - "ldr b3aq, [%[b_ptr], #48]\n" + "fmla v16.4s, b0a.4s, a0.s[3]\n" + "fmla v20.4s, b0a.4s, a1.s[3]\n" 
+ "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, b0a.4s, a2.s[3]\n" + "fmla v28.4s, b0a.4s, a3.s[3]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0.s[3]\n" + "fmla v21.4s, b1a.4s, a1.s[3]\n" + "fmla v25.4s, b1a.4s, a2.s[3]\n" + "fmla v29.4s, b1a.4s, a3.s[3]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0.s[3]\n" + "fmla v22.4s, b2a.4s, a1.s[3]\n" + "fmla v26.4s, b2a.4s, a2.s[3]\n" + "fmla v30.4s, b2a.4s, a3.s[3]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0.s[3]\n" + "fmla v23.4s, b3a.4s, a1.s[3]\n" + "fmla v27.4s, b3a.4s, a2.s[3]\n" + "fmla v31.4s, b3a.4s, a3.s[3]\n" + "ldr b3aq, [%[b_ptr], #48]\n" // Unroll 4 - "fmla v16.4s, bb0.4s, a0a.s[0]\n" - "fmla v20.4s, bb0.4s, a1a.s[0]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v24.4s, bb0.4s, a2a.s[0]\n" - "fmla v28.4s, bb0.4s, a3a.s[0]\n" - "ldr b0q, [%[b_ptr]]\n" - - "fmla v17.4s, bb1.4s, a0a.s[0]\n" - "fmla v21.4s, bb1.4s, a1a.s[0]\n" - "fmla v25.4s, bb1.4s, a2a.s[0]\n" - "fmla v29.4s, bb1.4s, a3a.s[0]\n" - "ldr b1q, [%[b_ptr], #16]\n" - - "fmla v18.4s, bb2.4s, a0a.s[0]\n" - "fmla v22.4s, bb2.4s, a1a.s[0]\n" - "fmla v26.4s, bb2.4s, a2a.s[0]\n" - "fmla v30.4s, bb2.4s, a3a.s[0]\n" - "ldr b2q, [%[b_ptr], #32]\n" - - "fmla v19.4s, bb3.4s, a0a.s[0]\n" - "fmla v23.4s, bb3.4s, a1a.s[0]\n" - "fmla v27.4s, bb3.4s, a2a.s[0]\n" - "fmla v31.4s, bb3.4s, a3a.s[0]\n" - "ldr b3q, [%[b_ptr], #48]\n" + "fmla v16.4s, bb0.4s, a0a.s[0]\n" + "fmla v20.4s, bb0.4s, a1a.s[0]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, bb0.4s, a2a.s[0]\n" + "fmla v28.4s, bb0.4s, a3a.s[0]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0a.s[0]\n" + "fmla v21.4s, bb1.4s, a1a.s[0]\n" + "fmla v25.4s, bb1.4s, a2a.s[0]\n" + "fmla v29.4s, bb1.4s, a3a.s[0]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0a.s[0]\n" + "fmla v22.4s, bb2.4s, a1a.s[0]\n" + "fmla v26.4s, bb2.4s, a2a.s[0]\n" + "fmla v30.4s, bb2.4s, a3a.s[0]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0a.s[0]\n" + "fmla v23.4s, bb3.4s, a1a.s[0]\n" + "fmla v27.4s, bb3.4s, a2a.s[0]\n" + "fmla v31.4s, bb3.4s, a3a.s[0]\n" + "ldr b3q, [%[b_ptr], #48]\n" // Unroll 5 - "fmla v16.4s, b0a.4s, a0a.s[1]\n" - "fmla v20.4s, b0a.4s, a1a.s[1]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v24.4s, b0a.4s, a2a.s[1]\n" - "fmla v28.4s, b0a.4s, a3a.s[1]\n" - "ldr b0aq, [%[b_ptr]]\n" - - "fmla v17.4s, b1a.4s, a0a.s[1]\n" - "fmla v21.4s, b1a.4s, a1a.s[1]\n" - "fmla v25.4s, b1a.4s, a2a.s[1]\n" - "fmla v29.4s, b1a.4s, a3a.s[1]\n" - "ldr b1aq, [%[b_ptr], #16]\n" - - "fmla v18.4s, b2a.4s, a0a.s[1]\n" - "fmla v22.4s, b2a.4s, a1a.s[1]\n" - "fmla v26.4s, b2a.4s, a2a.s[1]\n" - "fmla v30.4s, b2a.4s, a3a.s[1]\n" - "ldr b2aq, [%[b_ptr], #32]\n" - - "fmla v19.4s, b3a.4s, a0a.s[1]\n" - "fmla v23.4s, b3a.4s, a1a.s[1]\n" - "fmla v27.4s, b3a.4s, a2a.s[1]\n" - "fmla v31.4s, b3a.4s, a3a.s[1]\n" - "ldr b3aq, [%[b_ptr], #48]\n" + "fmla v16.4s, b0a.4s, a0a.s[1]\n" + "fmla v20.4s, b0a.4s, a1a.s[1]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, b0a.4s, a2a.s[1]\n" + "fmla v28.4s, b0a.4s, a3a.s[1]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0a.s[1]\n" + "fmla v21.4s, b1a.4s, a1a.s[1]\n" + "fmla v25.4s, b1a.4s, a2a.s[1]\n" + "fmla v29.4s, b1a.4s, a3a.s[1]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0a.s[1]\n" + "fmla v22.4s, b2a.4s, a1a.s[1]\n" + "fmla v26.4s, b2a.4s, a2a.s[1]\n" + "fmla v30.4s, b2a.4s, a3a.s[1]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0a.s[1]\n" + "fmla v23.4s, b3a.4s, a1a.s[1]\n" + "fmla v27.4s, 
b3a.4s, a2a.s[1]\n" + "fmla v31.4s, b3a.4s, a3a.s[1]\n" + "ldr b3aq, [%[b_ptr], #48]\n" // Unroll 6 - "fmla v16.4s, bb0.4s, a0a.s[2]\n" - "fmla v20.4s, bb0.4s, a1a.s[2]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v24.4s, bb0.4s, a2a.s[2]\n" - "fmla v28.4s, bb0.4s, a3a.s[2]\n" - - "fmla v17.4s, bb1.4s, a0a.s[2]\n" - "fmla v21.4s, bb1.4s, a1a.s[2]\n" - "fmla v25.4s, bb1.4s, a2a.s[2]\n" - "fmla v29.4s, bb1.4s, a3a.s[2]\n" - - "fmla v18.4s, bb2.4s, a0a.s[2]\n" - "fmla v22.4s, bb2.4s, a1a.s[2]\n" - "fmla v26.4s, bb2.4s, a2a.s[2]\n" - "fmla v30.4s, bb2.4s, a3a.s[2]\n" - - "fmla v19.4s, bb3.4s, a0a.s[2]\n" - "fmla v23.4s, bb3.4s, a1a.s[2]\n" - "fmla v27.4s, bb3.4s, a2a.s[2]\n" - "fmla v31.4s, bb3.4s, a3a.s[2]\n" + "fmla v16.4s, bb0.4s, a0a.s[2]\n" + "fmla v20.4s, bb0.4s, a1a.s[2]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, bb0.4s, a2a.s[2]\n" + "fmla v28.4s, bb0.4s, a3a.s[2]\n" + + "fmla v17.4s, bb1.4s, a0a.s[2]\n" + "fmla v21.4s, bb1.4s, a1a.s[2]\n" + "fmla v25.4s, bb1.4s, a2a.s[2]\n" + "fmla v29.4s, bb1.4s, a3a.s[2]\n" + + "fmla v18.4s, bb2.4s, a0a.s[2]\n" + "fmla v22.4s, bb2.4s, a1a.s[2]\n" + "fmla v26.4s, bb2.4s, a2a.s[2]\n" + "fmla v30.4s, bb2.4s, a3a.s[2]\n" + + "fmla v19.4s, bb3.4s, a0a.s[2]\n" + "fmla v23.4s, bb3.4s, a1a.s[2]\n" + "fmla v27.4s, bb3.4s, a2a.s[2]\n" + "fmla v31.4s, bb3.4s, a3a.s[2]\n" // Unroll 7 - "fmla v16.4s, b0a.4s, a0a.s[3]\n" - "fmla v17.4s, b1a.4s, a0a.s[3]\n" - "fmla v18.4s, b2a.4s, a0a.s[3]\n" - "fmla v19.4s, b3a.4s, a0a.s[3]\n" - "cbnz %w[odds], 6f\n" - - "fmla v20.4s, b0a.4s, a1a.s[3]\n" - "str q16, [%[c_ptr0]]\n" - "fmla v21.4s, b1a.4s, a1a.s[3]\n" - "str q17, [%[c_ptr0], #16]\n" - "fmla v22.4s, b2a.4s, a1a.s[3]\n" - "str q18, [%[c_ptr0], #32]\n" - "fmla v23.4s, b3a.4s, a1a.s[3]\n" - "str q19, [%[c_ptr0], #48]\n" - - "fmla v24.4s, b0a.4s, a2a.s[3]\n" - "str q20, [%[c_ptr1]]\n" - "fmla v25.4s, b1a.4s, a2a.s[3]\n" - "str q21, [%[c_ptr1], #16]\n" - "fmla v26.4s, b2a.4s, a2a.s[3]\n" - "str q22, [%[c_ptr1], #32]\n" - "fmla v27.4s, b3a.4s, a2a.s[3]\n" - "str q23, [%[c_ptr1], #48]\n" - - "fmla v28.4s, b0a.4s, a3a.s[3]\n" - "str q24, [%[c_ptr2]]\n" - "fmla v29.4s, b1a.4s, a3a.s[3]\n" - "str q25, [%[c_ptr2], #16]\n" - "fmla v30.4s, b2a.4s, a3a.s[3]\n" - "str q26, [%[c_ptr2], #32]\n" - "fmla v31.4s, b3a.4s, a3a.s[3]\n" - "str q27, [%[c_ptr2], #48]\n" - "b 3f\n" + "fmla v16.4s, b0a.4s, a0a.s[3]\n" + "fmla v17.4s, b1a.4s, a0a.s[3]\n" + "fmla v18.4s, b2a.4s, a0a.s[3]\n" + "fmla v19.4s, b3a.4s, a0a.s[3]\n" + "cbnz %w[odds], 6f\n" + + "fmla v20.4s, b0a.4s, a1a.s[3]\n" + "str q16, [%[c_ptr0]]\n" + "fmla v21.4s, b1a.4s, a1a.s[3]\n" + "str q17, [%[c_ptr0], #16]\n" + "fmla v22.4s, b2a.4s, a1a.s[3]\n" + "str q18, [%[c_ptr0], #32]\n" + "fmla v23.4s, b3a.4s, a1a.s[3]\n" + "str q19, [%[c_ptr0], #48]\n" + + "fmla v24.4s, b0a.4s, a2a.s[3]\n" + "str q20, [%[c_ptr1]]\n" + "fmla v25.4s, b1a.4s, a2a.s[3]\n" + "str q21, [%[c_ptr1], #16]\n" + "fmla v26.4s, b2a.4s, a2a.s[3]\n" + "str q22, [%[c_ptr1], #32]\n" + "fmla v27.4s, b3a.4s, a2a.s[3]\n" + "str q23, [%[c_ptr1], #48]\n" + + "fmla v28.4s, b0a.4s, a3a.s[3]\n" + "str q24, [%[c_ptr2]]\n" + "fmla v29.4s, b1a.4s, a3a.s[3]\n" + "str q25, [%[c_ptr2], #16]\n" + "fmla v30.4s, b2a.4s, a3a.s[3]\n" + "str q26, [%[c_ptr2], #32]\n" + "fmla v31.4s, b3a.4s, a3a.s[3]\n" + "str q27, [%[c_ptr2], #48]\n" + "b 3f\n" // Odd K case: Just do 4 more. 
"2:\n" - "fmla v21.4s, bb1.4s, a1.s[0]\n" - "add %[a_ptr0], %[a_ptr0], #16\n" - "fmla v25.4s, bb1.4s, a2.s[0]\n" - "add %[a_ptr1], %[a_ptr1], #16\n" - "fmla v29.4s, bb1.4s, a3.s[0]\n" - "ldr b1q, [%[b_ptr], #16]\n" - - "fmla v18.4s, bb2.4s, a0.s[0]\n" - "add %[a_ptr2], %[a_ptr2], #16\n" - "fmla v22.4s, bb2.4s, a1.s[0]\n" - "add %[a_ptr3], %[a_ptr3], #16\n" - "fmla v26.4s, bb2.4s, a2.s[0]\n" - "fmla v30.4s, bb2.4s, a3.s[0]\n" - "ldr b2q, [%[b_ptr], #32]\n" - - "fmla v19.4s, bb3.4s, a0.s[0]\n" - "fmla v23.4s, bb3.4s, a1.s[0]\n" - "fmla v27.4s, bb3.4s, a2.s[0]\n" - "fmla v31.4s, bb3.4s, a3.s[0]\n" - "ldr b3q, [%[b_ptr], #48]\n" + "fmla v21.4s, bb1.4s, a1.s[0]\n" + "add %[a_ptr0], %[a_ptr0], #16\n" + "fmla v25.4s, bb1.4s, a2.s[0]\n" + "add %[a_ptr1], %[a_ptr1], #16\n" + "fmla v29.4s, bb1.4s, a3.s[0]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0.s[0]\n" + "add %[a_ptr2], %[a_ptr2], #16\n" + "fmla v22.4s, bb2.4s, a1.s[0]\n" + "add %[a_ptr3], %[a_ptr3], #16\n" + "fmla v26.4s, bb2.4s, a2.s[0]\n" + "fmla v30.4s, bb2.4s, a3.s[0]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0.s[0]\n" + "fmla v23.4s, bb3.4s, a1.s[0]\n" + "fmla v27.4s, bb3.4s, a2.s[0]\n" + "fmla v31.4s, bb3.4s, a3.s[0]\n" + "ldr b3q, [%[b_ptr], #48]\n" // Unroll 1 - "fmla v16.4s, b0a.4s, a0.s[1]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v20.4s, b0a.4s, a1.s[1]\n" - "fmla v24.4s, b0a.4s, a2.s[1]\n" - "fmla v28.4s, b0a.4s, a3.s[1]\n" - "ldr b0aq, [%[b_ptr]]\n" - - "fmla v17.4s, b1a.4s, a0.s[1]\n" - "fmla v21.4s, b1a.4s, a1.s[1]\n" - "fmla v25.4s, b1a.4s, a2.s[1]\n" - "fmla v29.4s, b1a.4s, a3.s[1]\n" - "ldr b1aq, [%[b_ptr], #16]\n" - - "fmla v18.4s, b2a.4s, a0.s[1]\n" - "fmla v22.4s, b2a.4s, a1.s[1]\n" - "fmla v26.4s, b2a.4s, a2.s[1]\n" - "fmla v30.4s, b2a.4s, a3.s[1]\n" - "ldr b2aq, [%[b_ptr], #32]\n" - - "fmla v19.4s, b3a.4s, a0.s[1]\n" - "fmla v23.4s, b3a.4s, a1.s[1]\n" - "fmla v27.4s, b3a.4s, a2.s[1]\n" - "fmla v31.4s, b3a.4s, a3.s[1]\n" - "ldr b3aq, [%[b_ptr], #48]\n" + "fmla v16.4s, b0a.4s, a0.s[1]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v20.4s, b0a.4s, a1.s[1]\n" + "fmla v24.4s, b0a.4s, a2.s[1]\n" + "fmla v28.4s, b0a.4s, a3.s[1]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0.s[1]\n" + "fmla v21.4s, b1a.4s, a1.s[1]\n" + "fmla v25.4s, b1a.4s, a2.s[1]\n" + "fmla v29.4s, b1a.4s, a3.s[1]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0.s[1]\n" + "fmla v22.4s, b2a.4s, a1.s[1]\n" + "fmla v26.4s, b2a.4s, a2.s[1]\n" + "fmla v30.4s, b2a.4s, a3.s[1]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0.s[1]\n" + "fmla v23.4s, b3a.4s, a1.s[1]\n" + "fmla v27.4s, b3a.4s, a2.s[1]\n" + "fmla v31.4s, b3a.4s, a3.s[1]\n" + "ldr b3aq, [%[b_ptr], #48]\n" // Unroll 2 - "fmla v16.4s, bb0.4s, a0.s[2]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v20.4s, bb0.4s, a1.s[2]\n" - "fmla v24.4s, bb0.4s, a2.s[2]\n" - "fmla v28.4s, bb0.4s, a3.s[2]\n" - - "fmla v17.4s, bb1.4s, a0.s[2]\n" - "fmla v21.4s, bb1.4s, a1.s[2]\n" - "fmla v25.4s, bb1.4s, a2.s[2]\n" - "fmla v29.4s, bb1.4s, a3.s[2]\n" - - "fmla v18.4s, bb2.4s, a0.s[2]\n" - "fmla v22.4s, bb2.4s, a1.s[2]\n" - "fmla v26.4s, bb2.4s, a2.s[2]\n" - "fmla v30.4s, bb2.4s, a3.s[2]\n" - - "fmla v19.4s, bb3.4s, a0.s[2]\n" - "fmla v23.4s, bb3.4s, a1.s[2]\n" - "fmla v27.4s, bb3.4s, a2.s[2]\n" - "fmla v31.4s, bb3.4s, a3.s[2]\n" + "fmla v16.4s, bb0.4s, a0.s[2]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v20.4s, bb0.4s, a1.s[2]\n" + "fmla v24.4s, bb0.4s, a2.s[2]\n" + "fmla v28.4s, bb0.4s, a3.s[2]\n" + + "fmla v17.4s, bb1.4s, a0.s[2]\n" + 
"fmla v21.4s, bb1.4s, a1.s[2]\n" + "fmla v25.4s, bb1.4s, a2.s[2]\n" + "fmla v29.4s, bb1.4s, a3.s[2]\n" + + "fmla v18.4s, bb2.4s, a0.s[2]\n" + "fmla v22.4s, bb2.4s, a1.s[2]\n" + "fmla v26.4s, bb2.4s, a2.s[2]\n" + "fmla v30.4s, bb2.4s, a3.s[2]\n" + + "fmla v19.4s, bb3.4s, a0.s[2]\n" + "fmla v23.4s, bb3.4s, a1.s[2]\n" + "fmla v27.4s, bb3.4s, a2.s[2]\n" + "fmla v31.4s, bb3.4s, a3.s[2]\n" // Unroll 3 - "fmla v16.4s, b0a.4s, a0.s[3]\n" - "fmla v17.4s, b1a.4s, a0.s[3]\n" - "fmla v18.4s, b2a.4s, a0.s[3]\n" - "fmla v19.4s, b3a.4s, a0.s[3]\n" - "cbnz %w[odds], 7f\n" - - "fmla v20.4s, b0a.4s, a1.s[3]\n" - "str q16, [%[c_ptr0]]\n" - "fmla v21.4s, b1a.4s, a1.s[3]\n" - "str q17, [%[c_ptr0], #16]\n" - "fmla v22.4s, b2a.4s, a1.s[3]\n" - "str q18, [%[c_ptr0], #32]\n" - "fmla v23.4s, b3a.4s, a1.s[3]\n" - "str q19, [%[c_ptr0], #48]\n" - - "fmla v24.4s, b0a.4s, a2.s[3]\n" - "str q20, [%[c_ptr1]]\n" - "fmla v25.4s, b1a.4s, a2.s[3]\n" - "str q21, [%[c_ptr1], #16]\n" - "fmla v26.4s, b2a.4s, a2.s[3]\n" - "str q22, [%[c_ptr1], #32]\n" - "fmla v27.4s, b3a.4s, a2.s[3]\n" - "str q23, [%[c_ptr1], #48]\n" - - "fmla v28.4s, b0a.4s, a3.s[3]\n" - "str q24, [%[c_ptr2]]\n" - "fmla v29.4s, b1a.4s, a3.s[3]\n" - "str q25, [%[c_ptr2], #16]\n" - "fmla v30.4s, b2a.4s, a3.s[3]\n" - "str q26, [%[c_ptr2], #32]\n" - "fmla v31.4s, b3a.4s, a3.s[3]\n" - "str q27, [%[c_ptr2], #48]\n" - "b 3f\n" + "fmla v16.4s, b0a.4s, a0.s[3]\n" + "fmla v17.4s, b1a.4s, a0.s[3]\n" + "fmla v18.4s, b2a.4s, a0.s[3]\n" + "fmla v19.4s, b3a.4s, a0.s[3]\n" + "cbnz %w[odds], 7f\n" + + "fmla v20.4s, b0a.4s, a1.s[3]\n" + "str q16, [%[c_ptr0]]\n" + "fmla v21.4s, b1a.4s, a1.s[3]\n" + "str q17, [%[c_ptr0], #16]\n" + "fmla v22.4s, b2a.4s, a1.s[3]\n" + "str q18, [%[c_ptr0], #32]\n" + "fmla v23.4s, b3a.4s, a1.s[3]\n" + "str q19, [%[c_ptr0], #48]\n" + + "fmla v24.4s, b0a.4s, a2.s[3]\n" + "str q20, [%[c_ptr1]]\n" + "fmla v25.4s, b1a.4s, a2.s[3]\n" + "str q21, [%[c_ptr1], #16]\n" + "fmla v26.4s, b2a.4s, a2.s[3]\n" + "str q22, [%[c_ptr1], #32]\n" + "fmla v27.4s, b3a.4s, a2.s[3]\n" + "str q23, [%[c_ptr1], #48]\n" + + "fmla v28.4s, b0a.4s, a3.s[3]\n" + "str q24, [%[c_ptr2]]\n" + "fmla v29.4s, b1a.4s, a3.s[3]\n" + "str q25, [%[c_ptr2], #16]\n" + "fmla v30.4s, b2a.4s, a3.s[3]\n" + "str q26, [%[c_ptr2], #32]\n" + "fmla v31.4s, b3a.4s, a3.s[3]\n" + "str q27, [%[c_ptr2], #48]\n" + "b 3f\n" // "Odd ones" - lead in from even "6:\n" - "fmla v20.4s, b0a.4s, a1a.s[3]\n" - "fmla v21.4s, b1a.4s, a1a.s[3]\n" - "ldr b0q, [%[b_ptr]]\n" - "fmla v22.4s, b2a.4s, a1a.s[3]\n" - "subs %w[odds], %w[odds], #1\n" - "fmla v23.4s, b3a.4s, a1a.s[3]\n" - "ldr b1q, [%[b_ptr], #16]\n" - - "fmla v24.4s, b0a.4s, a2a.s[3]\n" - "fmla v25.4s, b1a.4s, a2a.s[3]\n" - "ldr b2q, [%[b_ptr], #32]\n" - "fmla v26.4s, b2a.4s, a2a.s[3]\n" - "fmla v27.4s, b3a.4s, a2a.s[3]\n" - "ldr b3q, [%[b_ptr], #48]\n" - - "fmla v28.4s, b0a.4s, a3a.s[3]\n" - "ld1r {a0.4s}, [%[a_ptr0]], #4\n" - "fmla v29.4s, b1a.4s, a3a.s[3]\n" - "fmla v30.4s, b2a.4s, a3a.s[3]\n" - "ld1r {a1.4s}, [%[a_ptr1]], #4\n" - "fmla v31.4s, b3a.4s, a3a.s[3]\n" - - "fmla v16.4s, bb0.4s, a0.4s\n" - "beq 9f\n" - "b 8f\n" + "fmla v20.4s, b0a.4s, a1a.s[3]\n" + "fmla v21.4s, b1a.4s, a1a.s[3]\n" + "ldr b0q, [%[b_ptr]]\n" + "fmla v22.4s, b2a.4s, a1a.s[3]\n" + "subs %w[odds], %w[odds], #1\n" + "fmla v23.4s, b3a.4s, a1a.s[3]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v24.4s, b0a.4s, a2a.s[3]\n" + "fmla v25.4s, b1a.4s, a2a.s[3]\n" + "ldr b2q, [%[b_ptr], #32]\n" + "fmla v26.4s, b2a.4s, a2a.s[3]\n" + "fmla v27.4s, b3a.4s, a2a.s[3]\n" + "ldr b3q, [%[b_ptr], #48]\n" + + "fmla 
v28.4s, b0a.4s, a3a.s[3]\n" + "ld1r {a0.4s}, [%[a_ptr0]], #4\n" + "fmla v29.4s, b1a.4s, a3a.s[3]\n" + "fmla v30.4s, b2a.4s, a3a.s[3]\n" + "ld1r {a1.4s}, [%[a_ptr1]], #4\n" + "fmla v31.4s, b3a.4s, a3a.s[3]\n" + + "fmla v16.4s, bb0.4s, a0.4s\n" + "beq 9f\n" + "b 8f\n" // "Odd ones" - lead in from odd "7:\n" - "fmla v20.4s, b0a.4s, a1.s[3]\n" - "subs %w[odds], %w[odds], #1\n" - "fmla v21.4s, b1a.4s, a1.s[3]\n" - "ldr b0q, [%[b_ptr]]\n" - "fmla v22.4s, b2a.4s, a1.s[3]\n" - "fmla v23.4s, b3a.4s, a1.s[3]\n" - "ldr b1q, [%[b_ptr], #16]\n" - - "fmla v24.4s, b0a.4s, a2.s[3]\n" - "fmla v25.4s, b1a.4s, a2.s[3]\n" - "ldr b2q, [%[b_ptr], #32]\n" - "fmla v26.4s, b2a.4s, a2.s[3]\n" - "fmla v27.4s, b3a.4s, a2.s[3]\n" - "ldr b3q, [%[b_ptr], #48]\n" - - "fmla v28.4s, b0a.4s, a3.s[3]\n" - "ld1r {a0.4s}, [%[a_ptr0]], #4\n" - "fmla v29.4s, b1a.4s, a3.s[3]\n" - "fmla v30.4s, b2a.4s, a3.s[3]\n" - "ld1r {a1.4s}, [%[a_ptr1]], #4\n" - "fmla v31.4s, b3a.4s, a3.s[3]\n" - - "fmla v16.4s, bb0.4s, a0.4s\n" - "beq 9f\n" + "fmla v20.4s, b0a.4s, a1.s[3]\n" + "subs %w[odds], %w[odds], #1\n" + "fmla v21.4s, b1a.4s, a1.s[3]\n" + "ldr b0q, [%[b_ptr]]\n" + "fmla v22.4s, b2a.4s, a1.s[3]\n" + "fmla v23.4s, b3a.4s, a1.s[3]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v24.4s, b0a.4s, a2.s[3]\n" + "fmla v25.4s, b1a.4s, a2.s[3]\n" + "ldr b2q, [%[b_ptr], #32]\n" + "fmla v26.4s, b2a.4s, a2.s[3]\n" + "fmla v27.4s, b3a.4s, a2.s[3]\n" + "ldr b3q, [%[b_ptr], #48]\n" + + "fmla v28.4s, b0a.4s, a3.s[3]\n" + "ld1r {a0.4s}, [%[a_ptr0]], #4\n" + "fmla v29.4s, b1a.4s, a3.s[3]\n" + "fmla v30.4s, b2a.4s, a3.s[3]\n" + "ld1r {a1.4s}, [%[a_ptr1]], #4\n" + "fmla v31.4s, b3a.4s, a3.s[3]\n" + + "fmla v16.4s, bb0.4s, a0.4s\n" + "beq 9f\n" // "Odd ones" - loop "8:\n" - "fmla v17.4s, bb1.4s, a0.4s\n" - "ld1r {a2.4s}, [%[a_ptr2]], #4\n" - "fmla v18.4s, bb2.4s, a0.4s\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v19.4s, bb3.4s, a0.4s\n" - "ld1r {a3.4s}, [%[a_ptr3]], #4\n" - - "fmla v20.4s, bb0.4s, a1.4s\n" - "subs %w[odds], %w[odds], #1\n" - "fmla v21.4s, bb1.4s, a1.4s\n" - "ld1r {a0.4s}, [%[a_ptr0]], #4\n" - "fmla v22.4s, bb2.4s, a1.4s\n" - "fmla v23.4s, bb3.4s, a1.4s\n" - "ld1r {a1.4s}, [%[a_ptr1]], #4\n" - - "fmla v24.4s, bb0.4s, a2.4s\n" - "fmla v28.4s, bb0.4s, a3.4s\n" - "ldr b0q, [%[b_ptr]]\n" - "fmla v25.4s, bb1.4s, a2.4s\n" - "fmla v29.4s, bb1.4s, a3.4s\n" - "ldr b1q, [%[b_ptr], #16]\n" - - "fmla v26.4s, bb2.4s, a2.4s\n" - "fmla v30.4s, bb2.4s, a3.4s\n" - "ldr b2q, [%[b_ptr], #32]\n" - "fmla v27.4s, bb3.4s, a2.4s\n" - "fmla v31.4s, bb3.4s, a3.4s\n" - "ldr b3q, [%[b_ptr], #48]\n" - "fmla v16.4s, bb0.4s, a0.4s\n" - "bne 8b\n" + "fmla v17.4s, bb1.4s, a0.4s\n" + "ld1r {a2.4s}, [%[a_ptr2]], #4\n" + "fmla v18.4s, bb2.4s, a0.4s\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v19.4s, bb3.4s, a0.4s\n" + "ld1r {a3.4s}, [%[a_ptr3]], #4\n" + + "fmla v20.4s, bb0.4s, a1.4s\n" + "subs %w[odds], %w[odds], #1\n" + "fmla v21.4s, bb1.4s, a1.4s\n" + "ld1r {a0.4s}, [%[a_ptr0]], #4\n" + "fmla v22.4s, bb2.4s, a1.4s\n" + "fmla v23.4s, bb3.4s, a1.4s\n" + "ld1r {a1.4s}, [%[a_ptr1]], #4\n" + + "fmla v24.4s, bb0.4s, a2.4s\n" + "fmla v28.4s, bb0.4s, a3.4s\n" + "ldr b0q, [%[b_ptr]]\n" + "fmla v25.4s, bb1.4s, a2.4s\n" + "fmla v29.4s, bb1.4s, a3.4s\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v26.4s, bb2.4s, a2.4s\n" + "fmla v30.4s, bb2.4s, a3.4s\n" + "ldr b2q, [%[b_ptr], #32]\n" + "fmla v27.4s, bb3.4s, a2.4s\n" + "fmla v31.4s, bb3.4s, a3.4s\n" + "ldr b3q, [%[b_ptr], #48]\n" + "fmla v16.4s, bb0.4s, a0.4s\n" + "bne 8b\n" // "Odd ones" - detached final iteration "9:\n" - "fmla v17.4s, 
bb1.4s, a0.4s\n" - "ld1r {a2.4s}, [%[a_ptr2]], #4\n" - "fmla v18.4s, bb2.4s, a0.4s\n" - "fmla v19.4s, bb3.4s, a0.4s\n" - "ld1r {a3.4s}, [%[a_ptr3]], #4\n" - - "fmla v20.4s, bb0.4s, a1.4s\n" - "str q16, [%[c_ptr0]]\n" - "fmla v21.4s, bb1.4s, a1.4s\n" - "str q17, [%[c_ptr0], #16]\n" - "fmla v22.4s, bb2.4s, a1.4s\n" - "str q18, [%[c_ptr0], #32]\n" - "fmla v23.4s, bb3.4s, a1.4s\n" - "str q19, [%[c_ptr0], #48]\n" - - "fmla v24.4s, bb0.4s, a2.4s\n" - "str q20, [%[c_ptr1]]\n" - "fmla v25.4s, bb1.4s, a2.4s\n" - "str q21, [%[c_ptr1], #16]\n" - "fmla v26.4s, bb2.4s, a2.4s\n" - "str q22, [%[c_ptr1], #32]\n" - "fmla v27.4s, bb3.4s, a2.4s\n" - "str q23, [%[c_ptr1], #48]\n" - - "fmla v28.4s, bb0.4s, a3.4s\n" - "str q24, [%[c_ptr2]]\n" - "fmla v29.4s, bb1.4s, a3.4s\n" - "str q25, [%[c_ptr2], #16]\n" - "fmla v30.4s, bb2.4s, a3.4s\n" - "str q26, [%[c_ptr2], #32]\n" - "fmla v31.4s, bb3.4s, a3.4s\n" - "str q27, [%[c_ptr2], #48]\n" + "fmla v17.4s, bb1.4s, a0.4s\n" + "ld1r {a2.4s}, [%[a_ptr2]], #4\n" + "fmla v18.4s, bb2.4s, a0.4s\n" + "fmla v19.4s, bb3.4s, a0.4s\n" + "ld1r {a3.4s}, [%[a_ptr3]], #4\n" + + "fmla v20.4s, bb0.4s, a1.4s\n" + "str q16, [%[c_ptr0]]\n" + "fmla v21.4s, bb1.4s, a1.4s\n" + "str q17, [%[c_ptr0], #16]\n" + "fmla v22.4s, bb2.4s, a1.4s\n" + "str q18, [%[c_ptr0], #32]\n" + "fmla v23.4s, bb3.4s, a1.4s\n" + "str q19, [%[c_ptr0], #48]\n" + + "fmla v24.4s, bb0.4s, a2.4s\n" + "str q20, [%[c_ptr1]]\n" + "fmla v25.4s, bb1.4s, a2.4s\n" + "str q21, [%[c_ptr1], #16]\n" + "fmla v26.4s, bb2.4s, a2.4s\n" + "str q22, [%[c_ptr1], #32]\n" + "fmla v27.4s, bb3.4s, a2.4s\n" + "str q23, [%[c_ptr1], #48]\n" + + "fmla v28.4s, bb0.4s, a3.4s\n" + "str q24, [%[c_ptr2]]\n" + "fmla v29.4s, bb1.4s, a3.4s\n" + "str q25, [%[c_ptr2], #16]\n" + "fmla v30.4s, bb2.4s, a3.4s\n" + "str q26, [%[c_ptr2], #32]\n" + "fmla v31.4s, bb3.4s, a3.4s\n" + "str q27, [%[c_ptr2], #48]\n" "3:\n" "str q28, [%[c_ptr3]]\n" diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed.hpp index c89514f98e..a73bc76b5d 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -25,14 +25,13 @@ #ifdef __aarch64__ -namespace arm_gemm -{ +namespace arm_gemm { + // Actual kernel implementations void a64_sgemv_pretransposed(const float *, int, const float *, float *, float, int, int); // Pretransposed SGEMV strategy class. -class sgemv_pretransposed -{ +class sgemv_pretransposed { public: typedef float operand_type; typedef float result_type; @@ -47,19 +46,17 @@ public: * terms of this standard arrangement, so if the A matrix is in fact the * B matrix from a GEMM call, the sense of the transpose needs to be * reversed. 
*/ - static const int A_interleave = 32; - static const int A_block = 1; - static const bool A_transpose = false; + static const int A_interleave = 32; + static const int A_block = 1; + static const bool A_transpose = false; /* Kernel blocking parameters */ static const int out_width = 32; - static const int k_unroll = 1; + static const int k_unroll = 1; kern_type kernel = a64_sgemv_pretransposed; - sgemv_pretransposed(const CPUInfo *ci) - { - } + sgemv_pretransposed(const CPUInfo *ci) { } }; } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed/generic.cpp index 290759822a..165e0a60da 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed/generic.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed/generic.cpp @@ -30,15 +30,13 @@ #include "../../asmlib.hpp" #include "../../utils.hpp" -namespace arm_gemm -{ -void a64_sgemv_pretransposed(const float *A, int lda, const float *X, float *Y, float beta, int M, int N) -{ - const bool beta0 = (beta == 0.0f); - const bool beta1 = (beta == 1.0f); - - for(int x = 0; x < N; x += 32) - { +namespace arm_gemm { + +void a64_sgemv_pretransposed(const float *A, int lda, const float *X, float *Y, float beta, int M, int N) { + const bool beta0 = (beta==0.0f); + const bool beta1 = (beta==1.0f); + + for (int x=0; x= 8) - { - int k = (M / 8) - 1; - x0 = vld1q_f32(x_ptr); - - __asm __volatile( - "ldr q2, [%[a_ptr], #0]\n" - "ldr q3, [%[a_ptr], #16]\n" - "ldr q4, [%[a_ptr], #32]\n" - "ldr q5, [%[a_ptr], #48]\n" - "ldr q6, [%[a_ptr], #64]\n" - "ldr q7, [%[a_ptr], #80]\n" - "ldr q8, [%[a_ptr], #96]\n" - "ldr q9, [%[a_ptr], #112]\n" - "ldr q10, [%[a_ptr], #128]\n" - "ldr q11, [%[a_ptr], #144]\n" - "ldr q12, [%[a_ptr], #160]\n" - "ldr q13, [%[a_ptr], #176]\n" - "ldr q14, [%[a_ptr], #192]\n" - "ldr q15, [%[a_ptr], #208]\n" - "ldr q16, [%[a_ptr], #224]\n" - "ldr q17, [%[a_ptr], #240]\n" - "ldr q18, [%[a_ptr], #256]\n" - "ldr q19, [%[a_ptr], #272]\n" - "ldr q20, [%[a_ptr], #288]\n" - "ldr q21, [%[a_ptr], #304]\n" - "ldr q22, [%[a_ptr], #320]\n" - "ldr q23, [%[a_ptr], #336]\n" ASM_PREFETCH("[%[a_ptr], #384]") + if (M>=8) { + int k = (M/8)-1; + x0 = vld1q_f32(x_ptr); + + __asm __volatile ( + "ldr q2, [%[a_ptr], #0]\n" + "ldr q3, [%[a_ptr], #16]\n" + "ldr q4, [%[a_ptr], #32]\n" + "ldr q5, [%[a_ptr], #48]\n" + "ldr q6, [%[a_ptr], #64]\n" + "ldr q7, [%[a_ptr], #80]\n" + "ldr q8, [%[a_ptr], #96]\n" + "ldr q9, [%[a_ptr], #112]\n" + "ldr q10, [%[a_ptr], #128]\n" + "ldr q11, [%[a_ptr], #144]\n" + "ldr q12, [%[a_ptr], #160]\n" + "ldr q13, [%[a_ptr], #176]\n" + "ldr q14, [%[a_ptr], #192]\n" + "ldr q15, [%[a_ptr], #208]\n" + "ldr q16, [%[a_ptr], #224]\n" + "ldr q17, [%[a_ptr], #240]\n" + "ldr q18, [%[a_ptr], #256]\n" + "ldr q19, [%[a_ptr], #272]\n" + "ldr q20, [%[a_ptr], #288]\n" + "ldr q21, [%[a_ptr], #304]\n" + "ldr q22, [%[a_ptr], #320]\n" + "ldr q23, [%[a_ptr], #336]\n" + ASM_PREFETCH("[%[a_ptr], #384]") ASM_PREFETCH("[%[a_ptr], #448]") ASM_PREFETCH("[%[a_ptr], #512]") ASM_PREFETCH("[%[a_ptr], #576]") @@ -284,363 +218,377 @@ void a64_sgemv_pretransposed(const float *A, int lda, const float *X, float *Y, ASM_PREFETCH("[%[a_ptr], #1856]") ASM_PREFETCH("[%[a_ptr], #1920]") ASM_PREFETCH("[%[a_ptr], #1984]") - "add %[a_ptr], %[a_ptr], #352\n" + "add %[a_ptr], %[a_ptr], #352\n" - "cbz %w[k], 2f\n" + "cbz %w[k], 2f\n" "1:\n" // Unroll 0 - "fmla %[r0].4s, v2.4s, %[x0].s[0]\n" - "ldr %q[x0a], [%[x_ptr], #16]\n" - "fmla 
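// For reference, a scalar model of the 32-wide output panel this kernel
// accumulates. Two assumptions, stated because they are not visible in this
// hunk: the r0-r7 accumulators arrive pre-seeded from Y according to beta
// (zero when beta==0), and with A_interleave=32 / A_block=1 each of the M rows
// contributes 32 contiguous panel entries for outputs [x, x+32).
static void sgemv_pretransposed_panel_ref(const float *panel, const float *X,
                                          float *Y, float beta, int M) {
    for (int j = 0; j < 32; j++) {
        float acc = (beta == 0.0f) ? 0.0f : beta * Y[j];
        for (int i = 0; i < M; i++) {
            acc += X[i] * panel[i * 32 + j];    // one fmla lane of r0..r7
        }
        Y[j] = acc;
    }
}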
%[r1].4s, v3.4s, %[x0].s[0]\n" - "ldr q3, [%[a_ptr], #0]\n" - "subs %w[k], %w[k], #1\n" - "fmla %[r2].4s, v4.4s, %[x0].s[0]\n" - "ldr q4, [%[a_ptr], #16]\n" - "fmla %[r3].4s, v5.4s, %[x0].s[0]\n" - "ldr q5, [%[a_ptr], #32]\n" - "add %[x_ptr], %[x_ptr], #32\n" ASM_PREFETCH("[%[a_ptr], #1664]") - "fmla %[r4].4s, v6.4s, %[x0].s[0]\n" - "ldr q6, [%[a_ptr], #48]\n" - "fmla %[r5].4s, v7.4s, %[x0].s[0]\n" - "ldr q7, [%[a_ptr], #64]\n" - "fmla %[r6].4s, v8.4s, %[x0].s[0]\n" - "ldr q8, [%[a_ptr], #80]\n" - "fmla %[r7].4s, v9.4s, %[x0].s[0]\n" - "ldr q9, [%[a_ptr], #96]\n" ASM_PREFETCH("[%[a_ptr], #1728]") + "fmla %[r0].4s, v2.4s, %[x0].s[0]\n" + "ldr %q[x0a], [%[x_ptr], #16]\n" + "fmla %[r1].4s, v3.4s, %[x0].s[0]\n" + "ldr q3, [%[a_ptr], #0]\n" + "subs %w[k], %w[k], #1\n" + "fmla %[r2].4s, v4.4s, %[x0].s[0]\n" + "ldr q4, [%[a_ptr], #16]\n" + "fmla %[r3].4s, v5.4s, %[x0].s[0]\n" + "ldr q5, [%[a_ptr], #32]\n" + "add %[x_ptr], %[x_ptr], #32\n" + ASM_PREFETCH("[%[a_ptr], #1664]") + "fmla %[r4].4s, v6.4s, %[x0].s[0]\n" + "ldr q6, [%[a_ptr], #48]\n" + "fmla %[r5].4s, v7.4s, %[x0].s[0]\n" + "ldr q7, [%[a_ptr], #64]\n" + "fmla %[r6].4s, v8.4s, %[x0].s[0]\n" + "ldr q8, [%[a_ptr], #80]\n" + "fmla %[r7].4s, v9.4s, %[x0].s[0]\n" + "ldr q9, [%[a_ptr], #96]\n" + ASM_PREFETCH("[%[a_ptr], #1728]") // Unroll 1 - "fmla %[r0].4s, v10.4s, %[x0].s[1]\n" - "ldr q10, [%[a_ptr], #112]\n" - "fmla %[r1].4s, v11.4s, %[x0].s[1]\n" - "ldr q11, [%[a_ptr], #128]\n" - "fmla %[r2].4s, v12.4s, %[x0].s[1]\n" - "ldr q12, [%[a_ptr], #144]\n" - "fmla %[r3].4s, v13.4s, %[x0].s[1]\n" - "ldr q13, [%[a_ptr], #160]\n" ASM_PREFETCH("[%[a_ptr], #1792]") - "fmla %[r4].4s, v14.4s, %[x0].s[1]\n" - "ldr q14, [%[a_ptr], #176]\n" - "fmla %[r5].4s, v15.4s, %[x0].s[1]\n" - "ldr q15, [%[a_ptr], #192]\n" - "fmla %[r6].4s, v16.4s, %[x0].s[1]\n" - "ldr q16, [%[a_ptr], #208]\n" - "fmla %[r7].4s, v17.4s, %[x0].s[1]\n" - "ldr q17, [%[a_ptr], #224]\n" ASM_PREFETCH("[%[a_ptr], #1856]") + "fmla %[r0].4s, v10.4s, %[x0].s[1]\n" + "ldr q10, [%[a_ptr], #112]\n" + "fmla %[r1].4s, v11.4s, %[x0].s[1]\n" + "ldr q11, [%[a_ptr], #128]\n" + "fmla %[r2].4s, v12.4s, %[x0].s[1]\n" + "ldr q12, [%[a_ptr], #144]\n" + "fmla %[r3].4s, v13.4s, %[x0].s[1]\n" + "ldr q13, [%[a_ptr], #160]\n" + ASM_PREFETCH("[%[a_ptr], #1792]") + "fmla %[r4].4s, v14.4s, %[x0].s[1]\n" + "ldr q14, [%[a_ptr], #176]\n" + "fmla %[r5].4s, v15.4s, %[x0].s[1]\n" + "ldr q15, [%[a_ptr], #192]\n" + "fmla %[r6].4s, v16.4s, %[x0].s[1]\n" + "ldr q16, [%[a_ptr], #208]\n" + "fmla %[r7].4s, v17.4s, %[x0].s[1]\n" + "ldr q17, [%[a_ptr], #224]\n" + ASM_PREFETCH("[%[a_ptr], #1856]") // Unroll 2 - "fmla %[r0].4s, v18.4s, %[x0].s[2]\n" - "ldr q18, [%[a_ptr], #240]\n" - "fmla %[r1].4s, v19.4s, %[x0].s[2]\n" - "ldr q19, [%[a_ptr], #256]\n" - "fmla %[r2].4s, v20.4s, %[x0].s[2]\n" - "ldr q20, [%[a_ptr], #272]\n" - "fmla %[r3].4s, v21.4s, %[x0].s[2]\n" - "ldr q21, [%[a_ptr], #288]\n" ASM_PREFETCH("[%[a_ptr], #1920]") - "fmla %[r4].4s, v22.4s, %[x0].s[2]\n" - "ldr q22, [%[a_ptr], #304]\n" - "fmla %[r5].4s, v23.4s, %[x0].s[2]\n" - "ldr q23, [%[a_ptr], #320]\n" - "fmla %[r6].4s, v3.4s, %[x0].s[2]\n" - "ldr q2, [%[a_ptr], #336]\n" - "ldr q3, [%[a_ptr], #352]\n" - "fmla %[r7].4s, v4.4s, %[x0].s[2]\n" - "ldr q4, [%[a_ptr], #368]\n" ASM_PREFETCH("[%[a_ptr], #1984]") + "fmla %[r0].4s, v18.4s, %[x0].s[2]\n" + "ldr q18, [%[a_ptr], #240]\n" + "fmla %[r1].4s, v19.4s, %[x0].s[2]\n" + "ldr q19, [%[a_ptr], #256]\n" + "fmla %[r2].4s, v20.4s, %[x0].s[2]\n" + "ldr q20, [%[a_ptr], #272]\n" + "fmla %[r3].4s, v21.4s, %[x0].s[2]\n" + "ldr q21, 
[%[a_ptr], #288]\n" + ASM_PREFETCH("[%[a_ptr], #1920]") + "fmla %[r4].4s, v22.4s, %[x0].s[2]\n" + "ldr q22, [%[a_ptr], #304]\n" + "fmla %[r5].4s, v23.4s, %[x0].s[2]\n" + "ldr q23, [%[a_ptr], #320]\n" + "fmla %[r6].4s, v3.4s, %[x0].s[2]\n" + "ldr q2, [%[a_ptr], #336]\n" + "ldr q3, [%[a_ptr], #352]\n" + "fmla %[r7].4s, v4.4s, %[x0].s[2]\n" + "ldr q4, [%[a_ptr], #368]\n" + ASM_PREFETCH("[%[a_ptr], #1984]") // Unroll 3 - "fmla %[r0].4s, v5.4s, %[x0].s[3]\n" - "ldr q5, [%[a_ptr], #384]\n" - "fmla %[r1].4s, v6.4s, %[x0].s[3]\n" - "ldr q6, [%[a_ptr], #400]\n" - "fmla %[r2].4s, v7.4s, %[x0].s[3]\n" - "ldr q7, [%[a_ptr], #416]\n" - "fmla %[r3].4s, v8.4s, %[x0].s[3]\n" ASM_PREFETCH("[%[a_ptr], #2048]") - "ldr q8, [%[a_ptr], #432]\n" - "fmla %[r4].4s, v9.4s, %[x0].s[3]\n" - "ldr q9, [%[a_ptr], #448]\n" - "fmla %[r5].4s, v10.4s, %[x0].s[3]\n" - "ldr q10, [%[a_ptr], #464]\n" - "fmla %[r6].4s, v11.4s, %[x0].s[3]\n" - "ldr q11, [%[a_ptr], #480]\n" - "fmla %[r7].4s, v12.4s, %[x0].s[3]\n" - "ldr q12, [%[a_ptr], #496]\n" ASM_PREFETCH("[%[a_ptr], #2112]") + "fmla %[r0].4s, v5.4s, %[x0].s[3]\n" + "ldr q5, [%[a_ptr], #384]\n" + "fmla %[r1].4s, v6.4s, %[x0].s[3]\n" + "ldr q6, [%[a_ptr], #400]\n" + "fmla %[r2].4s, v7.4s, %[x0].s[3]\n" + "ldr q7, [%[a_ptr], #416]\n" + "fmla %[r3].4s, v8.4s, %[x0].s[3]\n" + ASM_PREFETCH("[%[a_ptr], #2048]") + "ldr q8, [%[a_ptr], #432]\n" + "fmla %[r4].4s, v9.4s, %[x0].s[3]\n" + "ldr q9, [%[a_ptr], #448]\n" + "fmla %[r5].4s, v10.4s, %[x0].s[3]\n" + "ldr q10, [%[a_ptr], #464]\n" + "fmla %[r6].4s, v11.4s, %[x0].s[3]\n" + "ldr q11, [%[a_ptr], #480]\n" + "fmla %[r7].4s, v12.4s, %[x0].s[3]\n" + "ldr q12, [%[a_ptr], #496]\n" + ASM_PREFETCH("[%[a_ptr], #2112]") // Unroll 4 - "fmla %[r0].4s, v13.4s, %[x0a].s[0]\n" - "ldr %q[x0], [%[x_ptr]]\n" - "fmla %[r1].4s, v14.4s, %[x0a].s[0]\n" - "ldr q14, [%[a_ptr], #512]\n" - "fmla %[r2].4s, v15.4s, %[x0a].s[0]\n" - "ldr q15, [%[a_ptr], #528]\n" - "fmla %[r3].4s, v16.4s, %[x0a].s[0]\n" ASM_PREFETCH("[%[a_ptr], #2176]") - "ldr q16, [%[a_ptr], #544]\n" - "fmla %[r4].4s, v17.4s, %[x0a].s[0]\n" - "ldr q17, [%[a_ptr], #560]\n" - "fmla %[r5].4s, v18.4s, %[x0a].s[0]\n" - "ldr q18, [%[a_ptr], #576]\n" - "fmla %[r6].4s, v19.4s, %[x0a].s[0]\n" - "ldr q19, [%[a_ptr], #592]\n" - "fmla %[r7].4s, v20.4s, %[x0a].s[0]\n" - "ldr q20, [%[a_ptr], #608]\n" ASM_PREFETCH("[%[a_ptr], #2240]") + "fmla %[r0].4s, v13.4s, %[x0a].s[0]\n" + "ldr %q[x0], [%[x_ptr]]\n" + "fmla %[r1].4s, v14.4s, %[x0a].s[0]\n" + "ldr q14, [%[a_ptr], #512]\n" + "fmla %[r2].4s, v15.4s, %[x0a].s[0]\n" + "ldr q15, [%[a_ptr], #528]\n" + "fmla %[r3].4s, v16.4s, %[x0a].s[0]\n" + ASM_PREFETCH("[%[a_ptr], #2176]") + "ldr q16, [%[a_ptr], #544]\n" + "fmla %[r4].4s, v17.4s, %[x0a].s[0]\n" + "ldr q17, [%[a_ptr], #560]\n" + "fmla %[r5].4s, v18.4s, %[x0a].s[0]\n" + "ldr q18, [%[a_ptr], #576]\n" + "fmla %[r6].4s, v19.4s, %[x0a].s[0]\n" + "ldr q19, [%[a_ptr], #592]\n" + "fmla %[r7].4s, v20.4s, %[x0a].s[0]\n" + "ldr q20, [%[a_ptr], #608]\n" + ASM_PREFETCH("[%[a_ptr], #2240]") // Unroll 5 - "fmla %[r0].4s, v21.4s, %[x0a].s[1]\n" - "ldr q21, [%[a_ptr], #624]\n" - "fmla %[r1].4s, v22.4s, %[x0a].s[1]\n" - "ldr q22, [%[a_ptr], #640]\n" - "fmla %[r2].4s, v23.4s, %[x0a].s[1]\n" - "ldr q23, [%[a_ptr], #656]\n" - "fmla %[r3].4s, v2.4s, %[x0a].s[1]\n" - "ldr q2, [%[a_ptr], #672]\n" ASM_PREFETCH("[%[a_ptr], #2304]") - "fmla %[r4].4s, v3.4s, %[x0a].s[1]\n" - "ldr q3, [%[a_ptr], #688]\n" - "fmla %[r5].4s, v4.4s, %[x0a].s[1]\n" - "ldr q4, [%[a_ptr], #704]\n" - "fmla %[r6].4s, v5.4s, %[x0a].s[1]\n" - "ldr q5, [%[a_ptr], #720]\n" 
- "fmla %[r7].4s, v6.4s, %[x0a].s[1]\n" - "ldr q6, [%[a_ptr], #736]\n" ASM_PREFETCH("[%[a_ptr], #2368]") + "fmla %[r0].4s, v21.4s, %[x0a].s[1]\n" + "ldr q21, [%[a_ptr], #624]\n" + "fmla %[r1].4s, v22.4s, %[x0a].s[1]\n" + "ldr q22, [%[a_ptr], #640]\n" + "fmla %[r2].4s, v23.4s, %[x0a].s[1]\n" + "ldr q23, [%[a_ptr], #656]\n" + "fmla %[r3].4s, v2.4s, %[x0a].s[1]\n" + "ldr q2, [%[a_ptr], #672]\n" + ASM_PREFETCH("[%[a_ptr], #2304]") + "fmla %[r4].4s, v3.4s, %[x0a].s[1]\n" + "ldr q3, [%[a_ptr], #688]\n" + "fmla %[r5].4s, v4.4s, %[x0a].s[1]\n" + "ldr q4, [%[a_ptr], #704]\n" + "fmla %[r6].4s, v5.4s, %[x0a].s[1]\n" + "ldr q5, [%[a_ptr], #720]\n" + "fmla %[r7].4s, v6.4s, %[x0a].s[1]\n" + "ldr q6, [%[a_ptr], #736]\n" + ASM_PREFETCH("[%[a_ptr], #2368]") // Unroll 6 - "fmla %[r0].4s, v7.4s, %[x0a].s[2]\n" - "ldr q7, [%[a_ptr], #752]\n" - "fmla %[r1].4s, v8.4s, %[x0a].s[2]\n" - "ldr q8, [%[a_ptr], #768]\n" - "fmla %[r2].4s, v9.4s, %[x0a].s[2]\n" - "ldr q9, [%[a_ptr], #784]\n" - "fmla %[r3].4s, v10.4s, %[x0a].s[2]\n" - "ldr q10, [%[a_ptr], #800]\n" ASM_PREFETCH("[%[a_ptr], #2432]") - "fmla %[r4].4s, v11.4s, %[x0a].s[2]\n" - "ldr q11, [%[a_ptr], #816]\n" - "fmla %[r5].4s, v12.4s, %[x0a].s[2]\n" - "ldr q12, [%[a_ptr], #832]\n" - "fmla %[r6].4s, v14.4s, %[x0a].s[2]\n" - "ldr q13, [%[a_ptr], #848]\n" - "ldr q14, [%[a_ptr], #864]\n" - "fmla %[r7].4s, v15.4s, %[x0a].s[2]\n" - "ldr q15, [%[a_ptr], #880]\n" ASM_PREFETCH("[%[a_ptr], #2496]") + "fmla %[r0].4s, v7.4s, %[x0a].s[2]\n" + "ldr q7, [%[a_ptr], #752]\n" + "fmla %[r1].4s, v8.4s, %[x0a].s[2]\n" + "ldr q8, [%[a_ptr], #768]\n" + "fmla %[r2].4s, v9.4s, %[x0a].s[2]\n" + "ldr q9, [%[a_ptr], #784]\n" + "fmla %[r3].4s, v10.4s, %[x0a].s[2]\n" + "ldr q10, [%[a_ptr], #800]\n" + ASM_PREFETCH("[%[a_ptr], #2432]") + "fmla %[r4].4s, v11.4s, %[x0a].s[2]\n" + "ldr q11, [%[a_ptr], #816]\n" + "fmla %[r5].4s, v12.4s, %[x0a].s[2]\n" + "ldr q12, [%[a_ptr], #832]\n" + "fmla %[r6].4s, v14.4s, %[x0a].s[2]\n" + "ldr q13, [%[a_ptr], #848]\n" + "ldr q14, [%[a_ptr], #864]\n" + "fmla %[r7].4s, v15.4s, %[x0a].s[2]\n" + "ldr q15, [%[a_ptr], #880]\n" + ASM_PREFETCH("[%[a_ptr], #2496]") // Unroll 7 - "fmla %[r0].4s, v16.4s, %[x0a].s[3]\n" - "ldr q16, [%[a_ptr], #896]\n" - "fmla %[r1].4s, v17.4s, %[x0a].s[3]\n" - "ldr q17, [%[a_ptr], #912]\n" - "fmla %[r2].4s, v18.4s, %[x0a].s[3]\n" - "ldr q18, [%[a_ptr], #928]\n" - "fmla %[r3].4s, v19.4s, %[x0a].s[3]\n" ASM_PREFETCH("[%[a_ptr], #2560]") - "ldr q19, [%[a_ptr], #944]\n" - "fmla %[r4].4s, v20.4s, %[x0a].s[3]\n" - "ldr q20, [%[a_ptr], #960]\n" - "fmla %[r5].4s, v21.4s, %[x0a].s[3]\n" - "ldr q21, [%[a_ptr], #976]\n" - "add %[a_ptr], %[a_ptr], #1024\n" - "fmla %[r6].4s, v22.4s, %[x0a].s[3]\n" - "ldr q22, [%[a_ptr], #-32]\n" - "fmla %[r7].4s, v23.4s, %[x0a].s[3]\n" - "ldr q23, [%[a_ptr], #-16]\n" ASM_PREFETCH("[%[a_ptr], #1600]") - "bne 1b\n" + "fmla %[r0].4s, v16.4s, %[x0a].s[3]\n" + "ldr q16, [%[a_ptr], #896]\n" + "fmla %[r1].4s, v17.4s, %[x0a].s[3]\n" + "ldr q17, [%[a_ptr], #912]\n" + "fmla %[r2].4s, v18.4s, %[x0a].s[3]\n" + "ldr q18, [%[a_ptr], #928]\n" + "fmla %[r3].4s, v19.4s, %[x0a].s[3]\n" + ASM_PREFETCH("[%[a_ptr], #2560]") + "ldr q19, [%[a_ptr], #944]\n" + "fmla %[r4].4s, v20.4s, %[x0a].s[3]\n" + "ldr q20, [%[a_ptr], #960]\n" + "fmla %[r5].4s, v21.4s, %[x0a].s[3]\n" + "ldr q21, [%[a_ptr], #976]\n" + "add %[a_ptr], %[a_ptr], #1024\n" + "fmla %[r6].4s, v22.4s, %[x0a].s[3]\n" + "ldr q22, [%[a_ptr], #-32]\n" + "fmla %[r7].4s, v23.4s, %[x0a].s[3]\n" + "ldr q23, [%[a_ptr], #-16]\n" + ASM_PREFETCH("[%[a_ptr], #1600]") + "bne 1b\n" // Detached 
final iteration "2:\n" // Unroll 0 - "fmla %[r0].4s, v2.4s, %[x0].s[0]\n" - "ldr %q[x0a], [%[x_ptr], #16]\n" - "fmla %[r1].4s, v3.4s, %[x0].s[0]\n" - "ldr q3, [%[a_ptr], #0]\n" - "subs %w[k], %w[k], #1\n" - "fmla %[r2].4s, v4.4s, %[x0].s[0]\n" - "ldr q4, [%[a_ptr], #16]\n" - "fmla %[r3].4s, v5.4s, %[x0].s[0]\n" - "ldr q5, [%[a_ptr], #32]\n" - "add %[x_ptr], %[x_ptr], #32\n" - "fmla %[r4].4s, v6.4s, %[x0].s[0]\n" - "ldr q6, [%[a_ptr], #48]\n" - "fmla %[r5].4s, v7.4s, %[x0].s[0]\n" - "ldr q7, [%[a_ptr], #64]\n" - "fmla %[r6].4s, v8.4s, %[x0].s[0]\n" - "ldr q8, [%[a_ptr], #80]\n" - "fmla %[r7].4s, v9.4s, %[x0].s[0]\n" - "ldr q9, [%[a_ptr], #96]\n" + "fmla %[r0].4s, v2.4s, %[x0].s[0]\n" + "ldr %q[x0a], [%[x_ptr], #16]\n" + "fmla %[r1].4s, v3.4s, %[x0].s[0]\n" + "ldr q3, [%[a_ptr], #0]\n" + "subs %w[k], %w[k], #1\n" + "fmla %[r2].4s, v4.4s, %[x0].s[0]\n" + "ldr q4, [%[a_ptr], #16]\n" + "fmla %[r3].4s, v5.4s, %[x0].s[0]\n" + "ldr q5, [%[a_ptr], #32]\n" + "add %[x_ptr], %[x_ptr], #32\n" + "fmla %[r4].4s, v6.4s, %[x0].s[0]\n" + "ldr q6, [%[a_ptr], #48]\n" + "fmla %[r5].4s, v7.4s, %[x0].s[0]\n" + "ldr q7, [%[a_ptr], #64]\n" + "fmla %[r6].4s, v8.4s, %[x0].s[0]\n" + "ldr q8, [%[a_ptr], #80]\n" + "fmla %[r7].4s, v9.4s, %[x0].s[0]\n" + "ldr q9, [%[a_ptr], #96]\n" // Unroll 1 - "fmla %[r0].4s, v10.4s, %[x0].s[1]\n" - "ldr q10, [%[a_ptr], #112]\n" - "fmla %[r1].4s, v11.4s, %[x0].s[1]\n" - "ldr q11, [%[a_ptr], #128]\n" - "fmla %[r2].4s, v12.4s, %[x0].s[1]\n" - "ldr q12, [%[a_ptr], #144]\n" - "fmla %[r3].4s, v13.4s, %[x0].s[1]\n" - "ldr q13, [%[a_ptr], #160]\n" - "fmla %[r4].4s, v14.4s, %[x0].s[1]\n" - "ldr q14, [%[a_ptr], #176]\n" - "fmla %[r5].4s, v15.4s, %[x0].s[1]\n" - "ldr q15, [%[a_ptr], #192]\n" - "fmla %[r6].4s, v16.4s, %[x0].s[1]\n" - "ldr q16, [%[a_ptr], #208]\n" - "fmla %[r7].4s, v17.4s, %[x0].s[1]\n" - "ldr q17, [%[a_ptr], #224]\n" + "fmla %[r0].4s, v10.4s, %[x0].s[1]\n" + "ldr q10, [%[a_ptr], #112]\n" + "fmla %[r1].4s, v11.4s, %[x0].s[1]\n" + "ldr q11, [%[a_ptr], #128]\n" + "fmla %[r2].4s, v12.4s, %[x0].s[1]\n" + "ldr q12, [%[a_ptr], #144]\n" + "fmla %[r3].4s, v13.4s, %[x0].s[1]\n" + "ldr q13, [%[a_ptr], #160]\n" + "fmla %[r4].4s, v14.4s, %[x0].s[1]\n" + "ldr q14, [%[a_ptr], #176]\n" + "fmla %[r5].4s, v15.4s, %[x0].s[1]\n" + "ldr q15, [%[a_ptr], #192]\n" + "fmla %[r6].4s, v16.4s, %[x0].s[1]\n" + "ldr q16, [%[a_ptr], #208]\n" + "fmla %[r7].4s, v17.4s, %[x0].s[1]\n" + "ldr q17, [%[a_ptr], #224]\n" // Unroll 2 - "fmla %[r0].4s, v18.4s, %[x0].s[2]\n" - "ldr q18, [%[a_ptr], #240]\n" - "fmla %[r1].4s, v19.4s, %[x0].s[2]\n" - "ldr q19, [%[a_ptr], #256]\n" - "fmla %[r2].4s, v20.4s, %[x0].s[2]\n" - "ldr q20, [%[a_ptr], #272]\n" - "fmla %[r3].4s, v21.4s, %[x0].s[2]\n" - "ldr q21, [%[a_ptr], #288]\n" - "fmla %[r4].4s, v22.4s, %[x0].s[2]\n" - "ldr q22, [%[a_ptr], #304]\n" - "fmla %[r5].4s, v23.4s, %[x0].s[2]\n" - "ldr q23, [%[a_ptr], #320]\n" - "fmla %[r6].4s, v3.4s, %[x0].s[2]\n" - "ldr q2, [%[a_ptr], #336]\n" - "ldr q3, [%[a_ptr], #352]\n" - "fmla %[r7].4s, v4.4s, %[x0].s[2]\n" - "ldr q4, [%[a_ptr], #368]\n" + "fmla %[r0].4s, v18.4s, %[x0].s[2]\n" + "ldr q18, [%[a_ptr], #240]\n" + "fmla %[r1].4s, v19.4s, %[x0].s[2]\n" + "ldr q19, [%[a_ptr], #256]\n" + "fmla %[r2].4s, v20.4s, %[x0].s[2]\n" + "ldr q20, [%[a_ptr], #272]\n" + "fmla %[r3].4s, v21.4s, %[x0].s[2]\n" + "ldr q21, [%[a_ptr], #288]\n" + "fmla %[r4].4s, v22.4s, %[x0].s[2]\n" + "ldr q22, [%[a_ptr], #304]\n" + "fmla %[r5].4s, v23.4s, %[x0].s[2]\n" + "ldr q23, [%[a_ptr], #320]\n" + "fmla %[r6].4s, v3.4s, %[x0].s[2]\n" + "ldr q2, [%[a_ptr], #336]\n" + 
"ldr q3, [%[a_ptr], #352]\n" + "fmla %[r7].4s, v4.4s, %[x0].s[2]\n" + "ldr q4, [%[a_ptr], #368]\n" // Unroll 3 - "fmla %[r0].4s, v5.4s, %[x0].s[3]\n" - "ldr q5, [%[a_ptr], #384]\n" - "fmla %[r1].4s, v6.4s, %[x0].s[3]\n" - "ldr q6, [%[a_ptr], #400]\n" - "fmla %[r2].4s, v7.4s, %[x0].s[3]\n" - "ldr q7, [%[a_ptr], #416]\n" - "fmla %[r3].4s, v8.4s, %[x0].s[3]\n" - "ldr q8, [%[a_ptr], #432]\n" - "fmla %[r4].4s, v9.4s, %[x0].s[3]\n" - "ldr q9, [%[a_ptr], #448]\n" - "fmla %[r5].4s, v10.4s, %[x0].s[3]\n" - "ldr q10, [%[a_ptr], #464]\n" - "fmla %[r6].4s, v11.4s, %[x0].s[3]\n" - "ldr q11, [%[a_ptr], #480]\n" - "fmla %[r7].4s, v12.4s, %[x0].s[3]\n" - "ldr q12, [%[a_ptr], #496]\n" + "fmla %[r0].4s, v5.4s, %[x0].s[3]\n" + "ldr q5, [%[a_ptr], #384]\n" + "fmla %[r1].4s, v6.4s, %[x0].s[3]\n" + "ldr q6, [%[a_ptr], #400]\n" + "fmla %[r2].4s, v7.4s, %[x0].s[3]\n" + "ldr q7, [%[a_ptr], #416]\n" + "fmla %[r3].4s, v8.4s, %[x0].s[3]\n" + "ldr q8, [%[a_ptr], #432]\n" + "fmla %[r4].4s, v9.4s, %[x0].s[3]\n" + "ldr q9, [%[a_ptr], #448]\n" + "fmla %[r5].4s, v10.4s, %[x0].s[3]\n" + "ldr q10, [%[a_ptr], #464]\n" + "fmla %[r6].4s, v11.4s, %[x0].s[3]\n" + "ldr q11, [%[a_ptr], #480]\n" + "fmla %[r7].4s, v12.4s, %[x0].s[3]\n" + "ldr q12, [%[a_ptr], #496]\n" // Unroll 4 - "fmla %[r0].4s, v13.4s, %[x0a].s[0]\n" - "fmla %[r1].4s, v14.4s, %[x0a].s[0]\n" - "ldr q14, [%[a_ptr], #512]\n" - "fmla %[r2].4s, v15.4s, %[x0a].s[0]\n" - "ldr q15, [%[a_ptr], #528]\n" - "fmla %[r3].4s, v16.4s, %[x0a].s[0]\n" - "ldr q16, [%[a_ptr], #544]\n" - "fmla %[r4].4s, v17.4s, %[x0a].s[0]\n" - "ldr q17, [%[a_ptr], #560]\n" - "fmla %[r5].4s, v18.4s, %[x0a].s[0]\n" - "ldr q18, [%[a_ptr], #576]\n" - "fmla %[r6].4s, v19.4s, %[x0a].s[0]\n" - "ldr q19, [%[a_ptr], #592]\n" - "fmla %[r7].4s, v20.4s, %[x0a].s[0]\n" - "ldr q20, [%[a_ptr], #608]\n" + "fmla %[r0].4s, v13.4s, %[x0a].s[0]\n" + "fmla %[r1].4s, v14.4s, %[x0a].s[0]\n" + "ldr q14, [%[a_ptr], #512]\n" + "fmla %[r2].4s, v15.4s, %[x0a].s[0]\n" + "ldr q15, [%[a_ptr], #528]\n" + "fmla %[r3].4s, v16.4s, %[x0a].s[0]\n" + "ldr q16, [%[a_ptr], #544]\n" + "fmla %[r4].4s, v17.4s, %[x0a].s[0]\n" + "ldr q17, [%[a_ptr], #560]\n" + "fmla %[r5].4s, v18.4s, %[x0a].s[0]\n" + "ldr q18, [%[a_ptr], #576]\n" + "fmla %[r6].4s, v19.4s, %[x0a].s[0]\n" + "ldr q19, [%[a_ptr], #592]\n" + "fmla %[r7].4s, v20.4s, %[x0a].s[0]\n" + "ldr q20, [%[a_ptr], #608]\n" // Unroll 5 - "fmla %[r0].4s, v21.4s, %[x0a].s[1]\n" - "ldr q21, [%[a_ptr], #624]\n" - "fmla %[r1].4s, v22.4s, %[x0a].s[1]\n" - "ldr q22, [%[a_ptr], #640]\n" - "fmla %[r2].4s, v23.4s, %[x0a].s[1]\n" - "ldr q23, [%[a_ptr], #656]\n" - "fmla %[r3].4s, v2.4s, %[x0a].s[1]\n" - "add %[a_ptr], %[a_ptr], #672\n" - "fmla %[r4].4s, v3.4s, %[x0a].s[1]\n" - "fmla %[r5].4s, v4.4s, %[x0a].s[1]\n" - "fmla %[r6].4s, v5.4s, %[x0a].s[1]\n" - "fmla %[r7].4s, v6.4s, %[x0a].s[1]\n" + "fmla %[r0].4s, v21.4s, %[x0a].s[1]\n" + "ldr q21, [%[a_ptr], #624]\n" + "fmla %[r1].4s, v22.4s, %[x0a].s[1]\n" + "ldr q22, [%[a_ptr], #640]\n" + "fmla %[r2].4s, v23.4s, %[x0a].s[1]\n" + "ldr q23, [%[a_ptr], #656]\n" + "fmla %[r3].4s, v2.4s, %[x0a].s[1]\n" + "add %[a_ptr], %[a_ptr], #672\n" + "fmla %[r4].4s, v3.4s, %[x0a].s[1]\n" + "fmla %[r5].4s, v4.4s, %[x0a].s[1]\n" + "fmla %[r6].4s, v5.4s, %[x0a].s[1]\n" + "fmla %[r7].4s, v6.4s, %[x0a].s[1]\n" // Unroll 6 - "fmla %[r0].4s, v7.4s, %[x0a].s[2]\n" - "fmla %[r1].4s, v8.4s, %[x0a].s[2]\n" - "fmla %[r2].4s, v9.4s, %[x0a].s[2]\n" - "fmla %[r3].4s, v10.4s, %[x0a].s[2]\n" - "fmla %[r4].4s, v11.4s, %[x0a].s[2]\n" - "fmla %[r5].4s, v12.4s, %[x0a].s[2]\n" - "fmla %[r6].4s, 
v14.4s, %[x0a].s[2]\n" - "fmla %[r7].4s, v15.4s, %[x0a].s[2]\n" + "fmla %[r0].4s, v7.4s, %[x0a].s[2]\n" + "fmla %[r1].4s, v8.4s, %[x0a].s[2]\n" + "fmla %[r2].4s, v9.4s, %[x0a].s[2]\n" + "fmla %[r3].4s, v10.4s, %[x0a].s[2]\n" + "fmla %[r4].4s, v11.4s, %[x0a].s[2]\n" + "fmla %[r5].4s, v12.4s, %[x0a].s[2]\n" + "fmla %[r6].4s, v14.4s, %[x0a].s[2]\n" + "fmla %[r7].4s, v15.4s, %[x0a].s[2]\n" // Unroll 7 - "fmla %[r0].4s, v16.4s, %[x0a].s[3]\n" - "fmla %[r1].4s, v17.4s, %[x0a].s[3]\n" - "fmla %[r2].4s, v18.4s, %[x0a].s[3]\n" - "fmla %[r3].4s, v19.4s, %[x0a].s[3]\n" - "fmla %[r4].4s, v20.4s, %[x0a].s[3]\n" - "fmla %[r5].4s, v21.4s, %[x0a].s[3]\n" - "fmla %[r6].4s, v22.4s, %[x0a].s[3]\n" - "fmla %[r7].4s, v23.4s, %[x0a].s[3]\n" - : - [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), - [x0] "+w"(x0), [x0a] "+w"(x0a), [k] "+r"(k), - [r0] "+w"(r0), [r1] "+w"(r1), [r2] "+w"(r2), [r3] "+w"(r3), - [r4] "+w"(r4), [r5] "+w"(r5), [r6] "+w"(r6), [r7] "+w"(r7) - : - : "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "cc", "memory"); + "fmla %[r0].4s, v16.4s, %[x0a].s[3]\n" + "fmla %[r1].4s, v17.4s, %[x0a].s[3]\n" + "fmla %[r2].4s, v18.4s, %[x0a].s[3]\n" + "fmla %[r3].4s, v19.4s, %[x0a].s[3]\n" + "fmla %[r4].4s, v20.4s, %[x0a].s[3]\n" + "fmla %[r5].4s, v21.4s, %[x0a].s[3]\n" + "fmla %[r6].4s, v22.4s, %[x0a].s[3]\n" + "fmla %[r7].4s, v23.4s, %[x0a].s[3]\n" + : + [a_ptr] "+r" (a_ptr), [x_ptr] "+r" (x_ptr), + [x0] "+w" (x0), [x0a] "+w" (x0a), [k] "+r" (k), + [r0] "+w" (r0), [r1] "+w" (r1), [r2] "+w" (r2), [r3] "+w" (r3), + [r4] "+w" (r4), [r5] "+w" (r5), [r6] "+w" (r6), [r7] "+w" (r7) + : + : "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "cc", "memory"); } // Deal with ragged M - if(M % 8) - { - int l = (M % 8) - 1; - - __asm __volatile( - "ldr q2, [%[a_ptr], #0]\n" - "ldr q3, [%[a_ptr], #16]\n" - "ldr q4, [%[a_ptr], #32]\n" - "ldr q5, [%[a_ptr], #48]\n" - "ldr q6, [%[a_ptr], #64]\n" - "ldr q7, [%[a_ptr], #80]\n" - "ldr q8, [%[a_ptr], #96]\n" - "ldr q9, [%[a_ptr], #112]\n" - "ldr %s[x0], [%[x_ptr]]\n" - "add %[a_ptr], %[a_ptr], #128\n" - "add %[x_ptr], %[x_ptr], #4\n" - - "cbz %w[l], 2f\n" + if (M % 8) { + int l=(M%8)-1; + + __asm __volatile ( + "ldr q2, [%[a_ptr], #0]\n" + "ldr q3, [%[a_ptr], #16]\n" + "ldr q4, [%[a_ptr], #32]\n" + "ldr q5, [%[a_ptr], #48]\n" + "ldr q6, [%[a_ptr], #64]\n" + "ldr q7, [%[a_ptr], #80]\n" + "ldr q8, [%[a_ptr], #96]\n" + "ldr q9, [%[a_ptr], #112]\n" + "ldr %s[x0], [%[x_ptr]]\n" + "add %[a_ptr], %[a_ptr], #128\n" + "add %[x_ptr], %[x_ptr], #4\n" + + "cbz %w[l], 2f\n" "1:\n" - "fmla %[r0].4s, v2.4s, %[x0].s[0]\n" - "ldr q2, [%[a_ptr], #0]\n" - "subs %w[l], %w[l], #1\n" - "fmla %[r1].4s, v3.4s, %[x0].s[0]\n" - "ldr q3, [%[a_ptr], #16]\n" - "fmla %[r2].4s, v4.4s, %[x0].s[0]\n" - "ldr q4, [%[a_ptr], #32]\n" - "fmla %[r3].4s, v5.4s, %[x0].s[0]\n" - "ldr q5, [%[a_ptr], #48]\n" - "fmla %[r4].4s, v6.4s, %[x0].s[0]\n" - "ldr q6, [%[a_ptr], #64]\n" - "fmla %[r5].4s, v7.4s, %[x0].s[0]\n" - "ldr q7, [%[a_ptr], #80]\n" - "fmla %[r6].4s, v8.4s, %[x0].s[0]\n" - "ldr q8, [%[a_ptr], #96]\n" - "fmla %[r7].4s, v9.4s, %[x0].s[0]\n" - "ldr q9, [%[a_ptr], #112]\n" - "ldr %s[x0], [%[x_ptr]]\n" - "add %[a_ptr], %[a_ptr], #128\n" - "add %[x_ptr], %[x_ptr], #4\n" - "bne 1b\n" + "fmla %[r0].4s, v2.4s, %[x0].s[0]\n" + "ldr q2, [%[a_ptr], #0]\n" + "subs %w[l], %w[l], #1\n" + "fmla %[r1].4s, 
v3.4s, %[x0].s[0]\n" + "ldr q3, [%[a_ptr], #16]\n" + "fmla %[r2].4s, v4.4s, %[x0].s[0]\n" + "ldr q4, [%[a_ptr], #32]\n" + "fmla %[r3].4s, v5.4s, %[x0].s[0]\n" + "ldr q5, [%[a_ptr], #48]\n" + "fmla %[r4].4s, v6.4s, %[x0].s[0]\n" + "ldr q6, [%[a_ptr], #64]\n" + "fmla %[r5].4s, v7.4s, %[x0].s[0]\n" + "ldr q7, [%[a_ptr], #80]\n" + "fmla %[r6].4s, v8.4s, %[x0].s[0]\n" + "ldr q8, [%[a_ptr], #96]\n" + "fmla %[r7].4s, v9.4s, %[x0].s[0]\n" + "ldr q9, [%[a_ptr], #112]\n" + "ldr %s[x0], [%[x_ptr]]\n" + "add %[a_ptr], %[a_ptr], #128\n" + "add %[x_ptr], %[x_ptr], #4\n" + "bne 1b\n" "2:\n" - "fmla %[r0].4s, v2.4s, %[x0].s[0]\n" - "fmla %[r1].4s, v3.4s, %[x0].s[0]\n" - "fmla %[r2].4s, v4.4s, %[x0].s[0]\n" - "fmla %[r3].4s, v5.4s, %[x0].s[0]\n" - "fmla %[r4].4s, v6.4s, %[x0].s[0]\n" - "fmla %[r5].4s, v7.4s, %[x0].s[0]\n" - "fmla %[r6].4s, v8.4s, %[x0].s[0]\n" - "fmla %[r7].4s, v9.4s, %[x0].s[0]\n" - : - [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), - [x0] "+w"(x0), [l] "+r"(l), - [r0] "+w"(r0), [r1] "+w"(r1), [r2] "+w"(r2), [r3] "+w"(r3), - [r4] "+w"(r4), [r5] "+w"(r5), [r6] "+w"(r6), [r7] "+w"(r7) - : - : "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", "memory"); + "fmla %[r0].4s, v2.4s, %[x0].s[0]\n" + "fmla %[r1].4s, v3.4s, %[x0].s[0]\n" + "fmla %[r2].4s, v4.4s, %[x0].s[0]\n" + "fmla %[r3].4s, v5.4s, %[x0].s[0]\n" + "fmla %[r4].4s, v6.4s, %[x0].s[0]\n" + "fmla %[r5].4s, v7.4s, %[x0].s[0]\n" + "fmla %[r6].4s, v8.4s, %[x0].s[0]\n" + "fmla %[r7].4s, v9.4s, %[x0].s[0]\n" + : + [a_ptr] "+r" (a_ptr), [x_ptr] "+r" (x_ptr), + [x0] "+w" (x0), [l] "+r" (l), + [r0] "+w" (r0), [r1] "+w" (r1), [r2] "+w" (r2), [r3] "+w" (r3), + [r4] "+w" (r4), [r5] "+w" (r5), [r6] "+w" (r6), [r7] "+w" (r7) + : + : "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", "memory"); } - if(l == 32) - { + if (l==32) { // Fast path vst1q_f32(y_ptr, r0); vst1q_f32(y_ptr + 4, r1); @@ -650,82 +598,48 @@ void a64_sgemv_pretransposed(const float *A, int lda, const float *X, float *Y, vst1q_f32(y_ptr + 20, r5); vst1q_f32(y_ptr + 24, r6); vst1q_f32(y_ptr + 28, r7); - } - else - { - int vecs = l / 4; - int oddbits = l % 4; + } else { + int vecs=l/4; + int oddbits=l%4; - if(oddbits) - { + if (oddbits) { // As above - slowest path deals with vectors plus odd bits float32x4_t oddvec; - do - { - if(vecs == 0) - { - oddvec = r0; - break; - } + do { + if (vecs==0) { oddvec=r0; break; } vst1q_f32(y_ptr, r0); - if(--vecs == 0) - { - oddvec = r1; - break; - } + if (--vecs==0) { oddvec=r1; break; } vst1q_f32(y_ptr + 4, r1); - if(--vecs == 0) - { - oddvec = r2; - break; - } + if (--vecs==0) { oddvec=r2; break; } vst1q_f32(y_ptr + 8, r2); - if(--vecs == 0) - { - oddvec = r3; - break; - } + if (--vecs==0) { oddvec=r3; break; } vst1q_f32(y_ptr + 12, r3); - if(--vecs == 0) - { - oddvec = r4; - break; - } + if (--vecs==0) { oddvec=r4; break; } vst1q_f32(y_ptr + 16, r4); - if(--vecs == 0) - { - oddvec = r5; - break; - } + if (--vecs==0) { oddvec=r5; break; } vst1q_f32(y_ptr + 20, r5); - if(--vecs == 0) - { - oddvec = r6; - break; - } + if (--vecs==0) { oddvec=r6; break; } vst1q_f32(y_ptr + 24, r6); - oddvec = r7; - } - while(0); + oddvec=r7; + } while (0); float *oddbase = y_ptr + l - oddbits; - switch(oddbits) - { + switch(oddbits) { case 3: vst1q_lane_f32(oddbase + 2, oddvec, 2); - // fall through + // fall through case 2: vst1q_lane_f32(oddbase + 1, oddvec, 1); - // fall through + // fall through case 1: vst1q_lane_f32(oddbase, oddvec, 0); break; @@ -734,56 +648,31 @@ void a64_sgemv_pretransposed(const float *A, int lda, const float *X, float *Y, // oddbits 
must be 1, 2 or 3. UNREACHABLE("Impossible case in switch."); } - } - else - { + } else { // As above - medium path deals with vectors only - do - { - if(vecs == 0) - { - UNREACHABLE("vecs and oddbits can't both be 0"); - } + do { + if (vecs==0) { UNREACHABLE("vecs and oddbits can't both be 0"); } vst1q_f32(y_ptr, r0); - if(--vecs == 0) - { - break; - } + if (--vecs==0) { break; } vst1q_f32(y_ptr + 4, r1); - if(--vecs == 0) - { - break; - } + if (--vecs==0) { break; } vst1q_f32(y_ptr + 8, r2); - if(--vecs == 0) - { - break; - } + if (--vecs==0) { break; } vst1q_f32(y_ptr + 12, r3); - if(--vecs == 0) - { - break; - } + if (--vecs==0) { break; } vst1q_f32(y_ptr + 16, r4); - if(--vecs == 0) - { - break; - } + if (--vecs==0) { break; } vst1q_f32(y_ptr + 20, r5); - if(--vecs == 0) - { - break; - } + if (--vecs==0) { break; } vst1q_f32(y_ptr + 24, r6); - } - while(0); + } while (0); } } } diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans.hpp index 5b9bd72c89..18c5c3a6dc 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -25,14 +25,13 @@ #ifdef __aarch64__ -namespace arm_gemm -{ +namespace arm_gemm { + // Actual kernel implementations void a64_sgemv_trans(const float *, const float *, float *, float, int, int, int); // Transposed SGEMV strategy class. -class sgemv_trans -{ +class sgemv_trans { public: typedef float operand_type; typedef float result_type; @@ -41,13 +40,11 @@ public: /* Kernel blocking parameters */ static const int out_width = 96; - static const int k_unroll = 1; + static const int k_unroll = 1; - kern_type kernel = a64_sgemv_trans; + kern_type kernel=a64_sgemv_trans; - sgemv_trans(const CPUInfo *ci) - { - } + sgemv_trans(const CPUInfo *ci) { } }; } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp index 8fa403bf02..64ef9d89a4 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp @@ -42,464 +42,472 @@ // higher performance, but that's left to the outer loop. In this kernel we // process all of M at the same time. + // How far ahead to prefetch for the first and subsequent prefetches. // These values work for A72 on JunoR2... 
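// Tail bookkeeping for the ragged store above, worked through for one case:
// with l = 27 valid outputs, vecs = 27/4 = 6 full vst1q_f32 stores cover the
// first 24 floats (r0..r5), oddbits = 27%4 = 3 leaves r6 as 'oddvec', and the
// lane stores land at oddbase = y_ptr + l - oddbits = y_ptr + 24, writing
// lanes 2, 1 and 0 via the fall-through switch.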
#define FIRST_PFD 9 #define PFD 6 -namespace arm_gemm -{ -void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, float beta, int lda, int M, int N) -{ +namespace arm_gemm { + +void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, float beta, int lda, int M, int N) { const float *a_ptr_base = Astart; - float *y_ptr = Ystart; + float *y_ptr = Ystart; register const float32x4_t vb asm("v1") = vdupq_n_f32(beta); - int firstpfd = FIRST_PFD; - if(firstpfd > M) - { - firstpfd = (M - 1); + int firstpfd=FIRST_PFD; + if (firstpfd > M) { + firstpfd = (M-1); } int pfd = PFD; - if(pfd > M) - { - pfd = (M - 1); + if (pfd > M) { + pfd = (M-1); } ptrdiff_t jump = lda * sizeof(int); - for(; N >= 96; N -= 96) - { - int k = M - 1; + for (;N>=96;N-=96) { + int k = M-1; - const float *a_ptr = a_ptr_base; - const float *x_ptr = Xstart; - const float *pf_ptr = a_ptr; + const float *a_ptr = a_ptr_base; + const float *x_ptr = Xstart; + const float *pf_ptr = a_ptr; const float *firstpf_ptr = a_ptr; - const float *pf_limit = a_ptr + (M * lda); + const float *pf_limit = a_ptr + (M * lda); - for(int i = 0; i < firstpfd; i++) - { + for (int i=0; i 0) - { + if (N>0) { // Handle N tail - up to 95 stragglers. // This is 0-23 vectors, plus optionally an 64-bit vector and/or a // single value for the remainder. // Independent pointers into the matrix for the odd 2 and odd 1. // Double up as flag to indicate whether they are needed. - const float *odd2_aptr = NULL; - const float *odd1_aptr = NULL; + const float *odd2_aptr=NULL; + const float *odd1_aptr=NULL; // Figure out how much work we need to do. - int numvecs = N / 4; - int rem = N % 4; - int k = M; + int numvecs = N/4; + int rem = N%4; + int k=M; // Set up pointers for the odd 2/1 if needed. - if(rem >= 2) - { + if (rem >= 2) { odd2_aptr = a_ptr_base + (numvecs * 4); } - if(rem & 1) - { - odd1_aptr = a_ptr_base + (numvecs * 4) + (odd2_aptr == NULL ? 0 : 2); + if (rem & 1) { + odd1_aptr = a_ptr_base + (numvecs * 4) + (odd2_aptr==NULL ? 0 : 2); } - const float *a_ptr = a_ptr_base; + const float *a_ptr = a_ptr_base; const float *firstpf_ptr = a_ptr_base; - const float *pf_ptr = a_ptr_base; - const float *pf_limit = a_ptr + (M * lda); + const float *pf_ptr = a_ptr_base; + const float *pf_limit = a_ptr + (M * lda); const float *x_ptr = Xstart; - int vecs = 0; // Working variable to count how many vectors to work on. - int dopf = 1; // Track whether we are doing prefetches. + int vecs=0; // Working variable to count how many vectors to work on. + int dopf=1; // Track whether we are doing prefetches. // Figure out how many cache lines we need to prefetch each time. int numpfs = (N + 15) / 16; // Do initial prefetches - for(int i = 0; i < firstpfd + 1; i++) - { + for (int i=0; i 1) - { - for(int i = 0; i < pfd + 1; i++) - { - switch(numpfs) - { + if (numpfs > 1) { + for (int i=0; i -inline void MergeResults(Tout *out, const Tin *in, int ldc, int y0, int ymax, int x0, int xmax, const Tout alpha, const Tout beta) -{ +namespace arm_gemm { + +template +inline void MergeResults(Tout * out, const Tin * in, int ldc, int y0, int ymax, int x0, int xmax, const Tout alpha, const Tout beta) { int full_y_blocks = (ymax - y0) / height; - int y_remainder = (ymax - y0) % height; - int y_blocks = full_y_blocks + (y_remainder ? 1 : 0); + int y_remainder = (ymax - y0) % height; + int y_blocks = full_y_blocks + (y_remainder ? 
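// Blocking example for the loops below: with ymax - y0 = 20 and height = 8,
// full_y_blocks = 2 and y_remainder = 4, so y_blocks = 3 and the last y block
// only fills 4 rows; the x direction is split the same way against 'width'.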
1 : 0); int full_x_blocks = (xmax - x0) / width; - int x_remainder = (xmax - x0) % width; - int x_blocks = full_x_blocks + (x_remainder ? 1 : 0); + int x_remainder = (xmax - x0) % width; + int x_blocks = full_x_blocks + (x_remainder ? 1 : 0); - for(int y_block = 0; y_block < y_blocks; y_block++) - { + for (int y_block = 0; y_block < y_blocks; y_block++) { int ybase = y0 + (y_block * height); int fill_rows = (y_block < full_y_blocks) ? height : y_remainder; - for(int x_block = 0; x_block < x_blocks; x_block++) - { + for (int x_block = 0; x_block < x_blocks; x_block++) { int xbase = x0 + (x_block * width); int fill_cols = (x_block < full_x_blocks) ? width : x_remainder; - for(int row = 0; row < fill_rows; row++) - { - for(int col = 0; col < fill_cols; col++) - { + for (int row=0; row < fill_rows; row++) { + for (int col=0; col < fill_cols; col++) { Tout &p = out[(ybase + row) * ldc + xbase + col]; - p = (p * beta) + (alpha * in[row * width + col]); + // Special case for beta==0 - don't read the input; + // (0 * x == 0) is not always true for FP types. + if (beta == static_cast(0)) { + p = (alpha * in[row * width + col]); + } else { + p = (p * beta) + (alpha * in[row * width + col]); + } } } diff --git a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp index b44e56499f..2b833937a8 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp @@ -27,9 +27,8 @@ #include -template <> -inline void MergeResults<8, 6>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float alpha, const float beta) -{ +template<> +inline void MergeResults<8, 6>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float alpha, const float beta) { const float *inptr = in; prefetch_6x(inptr); prefetch_6x(inptr + 96); @@ -37,8 +36,7 @@ inline void MergeResults<8, 6>(float *out, const float *in, const int ldout, con float32x4_t av = vdupq_n_f32(alpha); float32x4_t bv = vdupq_n_f32(beta); - for(int y = y0; y < ymax; y += 8) - { + for (int y=y0; y(float *out, const float *in, const int ldout, con prefetch_2x(outptr4); prefetch_2x(outptr5); - for(int i = x0; i < xmax; i += 8) - { + for (int i=x0; i= ymax) - { - switch((y + 5) - ymax) - { + if ((y+5) >= ymax) { + switch ((y + 5) - ymax) { case 4: outptr1 = dummyres; case 3: @@ -81,84 +76,168 @@ inline void MergeResults<8, 6>(float *out, const float *in, const int ldout, con } } - /* For ragged X, manually copy over the valid results. */ - if((i + 7) >= xmax) - { - for(int xi = 0; xi < 8; xi++) - { - if((i + xi) < xmax) - { - *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta); - outptr0++; - *outptr1 = (alpha * inptr[xi + 8]) + (*outptr1 * beta); - outptr1++; - *outptr2 = (alpha * inptr[xi + 16]) + (*outptr2 * beta); - outptr2++; - *outptr3 = (alpha * inptr[xi + 24]) + (*outptr3 * beta); - outptr3++; - *outptr4 = (alpha * inptr[xi + 32]) + (*outptr4 * beta); - outptr4++; - *outptr5 = (alpha * inptr[xi + 40]) + (*outptr5 * beta); - outptr5++; + if (beta == 0.0f) { + /* If beta=0, don't read the original input at all. */ + + /* For ragged X, manually copy over the valid results. 
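// Why the beta==0 branch above must not read 'p' at all: with IEEE floats a
// zero multiplier does not guarantee a zero product. For example, if the
// output buffer holds a NaN (or +/-inf) from earlier:
//   float p = std::numeric_limits<float>::quiet_NaN();
//   float merged = (p * 0.0f) + (alpha * in_val);   // NaN, not alpha * in_val
// so the beta==0 case is handled as a pure write of alpha * in.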
*/ + if ((i+7) >= xmax) { + for (int xi=0; xi<8; xi++) { + if ((i+xi) < xmax) { + *outptr0 = (alpha * inptr[xi]); + outptr0++; + *outptr1 = (alpha * inptr[xi + 8]); + outptr1++; + *outptr2 = (alpha * inptr[xi + 16]); + outptr2++; + *outptr3 = (alpha * inptr[xi + 24]); + outptr3++; + *outptr4 = (alpha * inptr[xi + 32]); + outptr4++; + *outptr5 = (alpha * inptr[xi + 40]); + outptr5++; + } } + inptr += 48; + } else { + /* Optimized routine to copy an entire block */ + __asm __volatile ( + // Rows 0-1 + "VLD1.32 {d0-d3}, [%[inptr]]!\n" + "VLD1.32 {d4-d7}, [%[inptr]]!\n" + + "VMUL.f32 q4, q0, %q[av]\n" + ASM_PREFETCH("[%[inptr], #352]") + "VMUL.f32 q5, q1, %q[av]\n" + "VST1.32 {d8-d11}, [%[outptr0]]!\n" + ASM_PREFETCH("[%[inptr], #416]") + "VMUL.f32 q6, q2, %q[av]\n" + ASM_PREFETCH("[%[inptr], #480]") + "VMUL.f32 q7, q3, %q[av]\n" + "VST1.32 {d12-d15}, [%[outptr1]]!\n" + + // Rows 2-3 + "VLD1.32 {d0-d3}, [%[inptr]]!\n" + "VLD1.32 {d4-d7}, [%[inptr]]!\n" + + "VMUL.f32 q4, q0, %q[av]\n" + ASM_PREFETCH("[%[outptr0], #96]") + "VMUL.f32 q5, q1, %q[av]\n" + "VST1.32 {d8-d11}, [%[outptr2]]!\n" + ASM_PREFETCH("[%[outptr1], #96]") + "VMUL.f32 q6, q2, %q[av]\n" + ASM_PREFETCH("[%[outptr2], #96]") + "VMUL.f32 q7, q3, %q[av]\n" + "VST1.32 {d12-d15}, [%[outptr3]]!\n" + + // Rows 4-5 + "VLD1.32 {d0-d3}, [%[inptr]]!\n" + "VLD1.32 {d4-d7}, [%[inptr]]!\n" + + "VMUL.f32 q4, q0, %q[av]\n" + ASM_PREFETCH("[%[outptr3], #96]") + "VMUL.f32 q5, q1, %q[av]\n" + "VST1.32 {d8-d11}, [%[outptr4]]!\n" + ASM_PREFETCH("[%[outptr4], #96]") + "VMUL.f32 q6, q2, %q[av]\n" + ASM_PREFETCH("[%[outptr5], #128]") + "VMUL.f32 q7, q3, %q[av]\n" + "VST1.32 {d12-d15}, [%[outptr5]]!\n" + : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3), + [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [inptr] "+r" (inptr) + : [av] "w" (av), [bv] "w" (bv) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7" + ); + } + } else { + /* Non-zero beta: Read output and apply beta. */ + + /* For ragged X, manually copy over the valid results. 
*/ + if ((i+7) >= xmax) { + for (int xi=0; xi<8; xi++) { + if ((i+xi) < xmax) { + *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta); + outptr0++; + *outptr1 = (alpha * inptr[xi + 8]) + (*outptr1 * beta); + outptr1++; + *outptr2 = (alpha * inptr[xi + 16]) + (*outptr2 * beta); + outptr2++; + *outptr3 = (alpha * inptr[xi + 24]) + (*outptr3 * beta); + outptr3++; + *outptr4 = (alpha * inptr[xi + 32]) + (*outptr4 * beta); + outptr4++; + *outptr5 = (alpha * inptr[xi + 40]) + (*outptr5 * beta); + outptr5++; + } + } + inptr += 48; + } else { + /* Optimized routine to copy an entire block */ + __asm __volatile ( + // Rows 0-1 + "VLD1.32 {d8-d11}, [%[outptr0]]\n" + "VMUL.f32 q4, q4, %q[bv]\n" + "VLD1.32 {d12-d15}, [%[outptr1]]\n" + "VMUL.f32 q5, q5, %q[bv]\n" + "VLD1.32 {d0-d3}, [%[inptr]]!\n" + "VMUL.f32 q6, q6, %q[bv]\n" + "VLD1.32 {d4-d7}, [%[inptr]]!\n" + "VMUL.f32 q7, q7, %q[bv]\n" + + "VMLA.f32 q4, q0, %q[av]\n" + ASM_PREFETCH("[%[inptr], #352]") + "VMLA.f32 q5, q1, %q[av]\n" + "VST1.32 {d8-d11}, [%[outptr0]]!\n" + ASM_PREFETCH("[%[inptr], #416]") + "VMLA.f32 q6, q2, %q[av]\n" + ASM_PREFETCH("[%[inptr], #480]") + "VMLA.f32 q7, q3, %q[av]\n" + "VST1.32 {d12-d15}, [%[outptr1]]!\n" + + // Rows 2-3 + "VLD1.32 {d8-d11}, [%[outptr2]]\n" + "VMUL.f32 q4, q4, %q[bv]\n" + "VLD1.32 {d12-d15}, [%[outptr3]]\n" + "VMUL.f32 q5, q5, %q[bv]\n" + "VLD1.32 {d0-d3}, [%[inptr]]!\n" + "VMUL.f32 q6, q6, %q[bv]\n" + "VLD1.32 {d4-d7}, [%[inptr]]!\n" + "VMUL.f32 q7, q7, %q[bv]\n" + + "VMLA.f32 q4, q0, %q[av]\n" + ASM_PREFETCH("[%[outptr0], #96]") + "VMLA.f32 q5, q1, %q[av]\n" + "VST1.32 {d8-d11}, [%[outptr2]]!\n" + ASM_PREFETCH("[%[outptr1], #96]") + "VMLA.f32 q6, q2, %q[av]\n" + ASM_PREFETCH("[%[outptr2], #96]") + "VMLA.f32 q7, q3, %q[av]\n" + "VST1.32 {d12-d15}, [%[outptr3]]!\n" + + // Rows 4-5 + "VLD1.32 {d8-d11}, [%[outptr4]]\n" + "VMUL.f32 q4, q4, %q[bv]\n" + "VLD1.32 {d12-d15}, [%[outptr5]]\n" + "VMUL.f32 q5, q5, %q[bv]\n" + "VLD1.32 {d0-d3}, [%[inptr]]!\n" + "VMUL.f32 q6, q6, %q[bv]\n" + "VLD1.32 {d4-d7}, [%[inptr]]!\n" + "VMUL.f32 q7, q7, %q[bv]\n" + + "VMLA.f32 q4, q0, %q[av]\n" + ASM_PREFETCH("[%[outptr3], #96]") + "VMLA.f32 q5, q1, %q[av]\n" + "VST1.32 {d8-d11}, [%[outptr4]]!\n" + ASM_PREFETCH("[%[outptr4], #96]") + "VMLA.f32 q6, q2, %q[av]\n" + ASM_PREFETCH("[%[outptr5], #128]") + "VMLA.f32 q7, q3, %q[av]\n" + "VST1.32 {d12-d15}, [%[outptr5]]!\n" + : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3), + [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [inptr] "+r" (inptr) + : [av] "w" (av), [bv] "w" (bv) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7" + ); } - inptr += 48; - } - else - { - /* Optimized routine to copy an entire block */ - __asm __volatile( - // Rows 0-1 - "VLD1.32 {d8-d11}, [%[outptr0]]\n" - "VMUL.f32 q4, q4, %q[bv]\n" - "VLD1.32 {d12-d15}, [%[outptr1]]\n" - "VMUL.f32 q5, q5, %q[bv]\n" - "VLD1.32 {d0-d3}, [%[inptr]]!\n" - "VMUL.f32 q6, q6, %q[bv]\n" - "VLD1.32 {d4-d7}, [%[inptr]]!\n" - "VMUL.f32 q7, q7, %q[bv]\n" - - "VMLA.f32 q4, q0, %q[av]\n" ASM_PREFETCH("[%[inptr], #352]") - "VMLA.f32 q5, q1, %q[av]\n" - "VST1.32 {d8-d11}, [%[outptr0]]!\n" ASM_PREFETCH("[%[inptr], #416]") "VMLA.f32 q6, q2, %q[av]\n" ASM_PREFETCH("[%[inptr], #480]") - "VMLA.f32 q7, q3, %q[av]\n" - "VST1.32 {d12-d15}, [%[outptr1]]!\n" - - // Rows 2-3 - "VLD1.32 {d8-d11}, [%[outptr2]]\n" - "VMUL.f32 q4, q4, %q[bv]\n" - "VLD1.32 {d12-d15}, [%[outptr3]]\n" - "VMUL.f32 q5, q5, %q[bv]\n" - "VLD1.32 {d0-d3}, [%[inptr]]!\n" - "VMUL.f32 q6, q6, %q[bv]\n" - "VLD1.32 
{d4-d7}, [%[inptr]]!\n" - "VMUL.f32 q7, q7, %q[bv]\n" - - "VMLA.f32 q4, q0, %q[av]\n" ASM_PREFETCH("[%[outptr0], #96]") - "VMLA.f32 q5, q1, %q[av]\n" - "VST1.32 {d8-d11}, [%[outptr2]]!\n" ASM_PREFETCH("[%[outptr1], #96]") "VMLA.f32 q6, q2, %q[av]\n" ASM_PREFETCH("[%[outptr2], #96]") - "VMLA.f32 q7, q3, %q[av]\n" - "VST1.32 {d12-d15}, [%[outptr3]]!\n" - - // Rows 4-5 - "VLD1.32 {d8-d11}, [%[outptr4]]\n" - "VMUL.f32 q4, q4, %q[bv]\n" - "VLD1.32 {d12-d15}, [%[outptr5]]\n" - "VMUL.f32 q5, q5, %q[bv]\n" - "VLD1.32 {d0-d3}, [%[inptr]]!\n" - "VMUL.f32 q6, q6, %q[bv]\n" - "VLD1.32 {d4-d7}, [%[inptr]]!\n" - "VMUL.f32 q7, q7, %q[bv]\n" - - "VMLA.f32 q4, q0, %q[av]\n" ASM_PREFETCH("[%[outptr3], #96]") - "VMLA.f32 q5, q1, %q[av]\n" - "VST1.32 {d8-d11}, [%[outptr4]]!\n" ASM_PREFETCH("[%[outptr4], #96]") "VMLA.f32 q6, q2, %q[av]\n" ASM_PREFETCH("[%[outptr5], #128]") - "VMLA.f32 q7, q3, %q[av]\n" - "VST1.32 {d12-d15}, [%[outptr5]]!\n" - : [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3), - [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), [inptr] "+r"(inptr) - : [av] "w"(av), [bv] "w"(bv) - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); } } } diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp index 3b59a43c52..f6befa2d14 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp @@ -25,9 +25,8 @@ #ifdef __aarch64__ -template <> -inline void MergeResults<12, 8>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float alpha, const float beta) -{ +template<> +inline void MergeResults<12, 8>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float alpha, const float beta) { const float *inptr = in; prefetch_6x(inptr); prefetch_6x(inptr + 96); @@ -35,8 +34,7 @@ inline void MergeResults<12, 8>(float *out, const float *in, const int ldout, co float32x4_t av = vdupq_n_f32(alpha); float32x4_t bv = vdupq_n_f32(beta); - for(int y = y0; y < ymax; y += 8) - { + for (int y=y0; y(float *out, const float *in, const int ldout, co prefetch_2x(outptr6); prefetch_2x(outptr7); - for(int i = x0; i < xmax; i += 12) - { + for (int i=x0; i= ymax) - { - switch((y + 7) - ymax) - { + if ((y+7) >= ymax) { + switch ((y + 7) - ymax) { case 6: outptr1 = dummyres; case 5: @@ -87,147 +82,259 @@ inline void MergeResults<12, 8>(float *out, const float *in, const int ldout, co } } - /* For ragged X, manually copy over the valid results. */ - if((i + 11) >= xmax) - { - for(int xi = 0; xi < 12; xi++) - { - if((i + xi) < xmax) - { - *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta); - outptr0++; - *outptr1 = (alpha * inptr[xi + 12]) + (*outptr1 * beta); - outptr1++; - *outptr2 = (alpha * inptr[xi + 24]) + (*outptr2 * beta); - outptr2++; - *outptr3 = (alpha * inptr[xi + 36]) + (*outptr3 * beta); - outptr3++; - *outptr4 = (alpha * inptr[xi + 48]) + (*outptr4 * beta); - outptr4++; - *outptr5 = (alpha * inptr[xi + 60]) + (*outptr5 * beta); - outptr5++; - *outptr6 = (alpha * inptr[xi + 72]) + (*outptr6 * beta); - outptr6++; - *outptr7 = (alpha * inptr[xi + 84]) + (*outptr7 * beta); - outptr7++; + if (beta==0.0f) { + /* If beta==0, don't read the original input at all. */ + + /* For ragged X, manually copy over the valid results. 
*/ + if ((i+11) >= xmax) { + for (int xi=0; xi<12; xi++) { + if ((i+xi) < xmax) { + *outptr0 = (alpha * inptr[xi]); + outptr0++; + *outptr1 = (alpha * inptr[xi + 12]); + outptr1++; + *outptr2 = (alpha * inptr[xi + 24]); + outptr2++; + *outptr3 = (alpha * inptr[xi + 36]); + outptr3++; + *outptr4 = (alpha * inptr[xi + 48]); + outptr4++; + *outptr5 = (alpha * inptr[xi + 60]); + outptr5++; + *outptr6 = (alpha * inptr[xi + 72]); + outptr6++; + *outptr7 = (alpha * inptr[xi + 84]); + outptr7++; + } } + inptr += 96; + } else { + /* Optimized routine to copy an entire block */ + __asm __volatile ( + // Rows 0-1 + "LDP q0, q1, [%[inptr]]\n" + "FMUL v16.4s, v0.4s, %[av].4s\n" + "LDP q2, q3, [%[inptr], #32]\n" + "FMUL v17.4s, v1.4s, %[av].4s\n" + "LDP q4, q5, [%[inptr], #64]\n" + "FMUL v18.4s, v2.4s, %[av].4s\n" + "STP q16, q17, [%[outptr0]], #32\n" + ASM_PREFETCH("[%[inptr], #768]") + "FMUL v19.4s, v3.4s, %[av].4s\n" + "STR q18, [%[outptr0]], #16\n" + "FMUL v20.4s, v4.4s, %[av].4s\n" + "STP q19, q20, [%[outptr1]], #32\n" + ASM_PREFETCH("[%[inptr], #832]") + "FMUL v21.4s, v5.4s, %[av].4s\n" + "STR q21, [%[outptr1]], #16\n" + + // Rows 2-3 + "LDP q0, q1, [%[inptr], #96]\n" + "FMUL v16.4s, v0.4s, %[av].4s\n" + "LDP q2, q3, [%[inptr], #128]\n" + "FMUL v17.4s, v1.4s, %[av].4s\n" + "LDP q4, q5, [%[inptr], #160]\n" + "FMUL v18.4s, v2.4s, %[av].4s\n" + "STP q16, q17, [%[outptr2]], #32\n" + ASM_PREFETCH("[%[inptr], #896]") + "FMUL v19.4s, v3.4s, %[av].4s\n" + "STR q18, [%[outptr2]], #16\n" + "FMUL v20.4s, v4.4s, %[av].4s\n" + "STP q19, q20, [%[outptr3]], #32\n" + ASM_PREFETCH("[%[inptr], #1024]") + "FMUL v21.4s, v5.4s, %[av].4s\n" + "STR q21, [%[outptr3]], #16\n" + + // Rows 4-5 + "LDP q0, q1, [%[inptr], #192]\n" + "FMUL v16.4s, v0.4s, %[av].4s\n" + "LDP q2, q3, [%[inptr], #224]\n" + "FMUL v17.4s, v1.4s, %[av].4s\n" + "LDP q4, q5, [%[inptr], #256]\n" + "FMUL v18.4s, v2.4s, %[av].4s\n" + "STP q16, q17, [%[outptr4]], #32\n" + ASM_PREFETCH("[%[inptr], #960]") + "FMUL v19.4s, v3.4s, %[av].4s\n" + "STR q18, [%[outptr4]], #16\n" + "FMUL v20.4s, v4.4s, %[av].4s\n" + "STP q19, q20, [%[outptr5]], #32\n" + ASM_PREFETCH("[%[inptr], #1088]") + "FMUL v21.4s, v5.4s, %[av].4s\n" + "STR q21, [%[outptr5]], #16\n" + + // Rows 6-7 + "LDP q0, q1, [%[inptr], #288]\n" + "FMUL v16.4s, v0.4s, %[av].4s\n" + "LDP q2, q3, [%[inptr], #320]\n" + "FMUL v17.4s, v1.4s, %[av].4s\n" + "LDP q4, q5, [%[inptr], #352]\n" + "FMUL v18.4s, v2.4s, %[av].4s\n" + "STP q16, q17, [%[outptr6]], #32\n" + "FMUL v19.4s, v3.4s, %[av].4s\n" + "STR q18, [%[outptr6]], #16\n" + "FMUL v20.4s, v4.4s, %[av].4s\n" + "STP q19, q20, [%[outptr7]], #32\n" + "FMUL v21.4s, v5.4s, %[av].4s\n" + "STR q21, [%[outptr7]], #16\n" + "ADD %[inptr], %[inptr], #384\n" + : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3), + [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [outptr6] "+r" (outptr6), [outptr7] "+r" (outptr7), + [inptr] "+r" (inptr) + : [av] "w" (av), [bv] "w" (bv) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21" + ); + } + } else { + /* For ragged X, manually copy over the valid results. 
*/ + if ((i+11) >= xmax) { + for (int xi=0; xi<12; xi++) { + if ((i+xi) < xmax) { + *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta); + outptr0++; + *outptr1 = (alpha * inptr[xi + 12]) + (*outptr1 * beta); + outptr1++; + *outptr2 = (alpha * inptr[xi + 24]) + (*outptr2 * beta); + outptr2++; + *outptr3 = (alpha * inptr[xi + 36]) + (*outptr3 * beta); + outptr3++; + *outptr4 = (alpha * inptr[xi + 48]) + (*outptr4 * beta); + outptr4++; + *outptr5 = (alpha * inptr[xi + 60]) + (*outptr5 * beta); + outptr5++; + *outptr6 = (alpha * inptr[xi + 72]) + (*outptr6 * beta); + outptr6++; + *outptr7 = (alpha * inptr[xi + 84]) + (*outptr7 * beta); + outptr7++; + } + } + inptr += 96; + } else { + /* Optimized routine to copy an entire block */ + __asm __volatile ( + // Rows 0-1 + "LDP q16, q17, [%[outptr0]]\n" + "FMUL v16.4s, v16.4s, %[bv].4s\n" + "LDR q18, [%[outptr0], #32]\n" + "FMUL v17.4s, v17.4s, %[bv].4s\n" + "LDP q19, q20, [%[outptr1]]\n" + "FMUL v18.4s, v18.4s, %[bv].4s\n" + "LDR q21, [%[outptr1], #32]\n" + ASM_PREFETCH("[%[inptr], #768]") + "FMUL v19.4s, v19.4s, %[bv].4s\n" + "LDP q0, q1, [%[inptr]]\n" + "FMUL v20.4s, v20.4s, %[bv].4s\n" + "LDP q2, q3, [%[inptr], #32]\n" + "FMUL v21.4s, v21.4s, %[bv].4s\n" + "LDP q4, q5, [%[inptr], #64]\n" + "FMLA v16.4s, v0.4s, %[av].4s\n" + ASM_PREFETCH("[%[inptr], #832]") + "FMLA v17.4s, v1.4s, %[av].4s\n" + "STP q16, q17, [%[outptr0]], #32\n" + "FMLA v18.4s, v2.4s, %[av].4s\n" + "STR q18, [%[outptr0]], #16\n" + "FMLA v19.4s, v3.4s, %[av].4s\n" + ASM_PREFETCH("[%[inptr], #896]") + "FMLA v20.4s, v4.4s, %[av].4s\n" + "STP q19, q20, [%[outptr1]], #32\n" + "FMLA v21.4s, v5.4s, %[av].4s\n" + "STR q21, [%[outptr1]], #16\n" + + // Rows 2-3 + "LDP q16, q17, [%[outptr2]]\n" + "FMUL v16.4s, v16.4s, %[bv].4s\n" + "LDR q18, [%[outptr2], #32]\n" + "FMUL v17.4s, v17.4s, %[bv].4s\n" + "LDP q19, q20, [%[outptr3]]\n" + "FMUL v18.4s, v18.4s, %[bv].4s\n" + "LDR q21, [%[outptr3], #32]\n" + ASM_PREFETCH("[%[inptr], #960]") + "FMUL v19.4s, v19.4s, %[bv].4s\n" + "LDP q0, q1, [%[inptr], #96]\n" + "FMUL v20.4s, v20.4s, %[bv].4s\n" + "LDP q2, q3, [%[inptr], #128]\n" + "FMUL v21.4s, v21.4s, %[bv].4s\n" + "LDP q4, q5, [%[inptr], #160]\n" + "FMLA v16.4s, v0.4s, %[av].4s\n" + ASM_PREFETCH("[%[inptr], #1024]") + "FMLA v17.4s, v1.4s, %[av].4s\n" + "STP q16, q17, [%[outptr2]], #32\n" + "FMLA v18.4s, v2.4s, %[av].4s\n" + "STR q18, [%[outptr2]], #16\n" + "FMLA v19.4s, v3.4s, %[av].4s\n" + ASM_PREFETCH("[%[inptr], #1088]") + "FMLA v20.4s, v4.4s, %[av].4s\n" + "STP q19, q20, [%[outptr3]], #32\n" + "FMLA v21.4s, v5.4s, %[av].4s\n" + "STR q21, [%[outptr3]], #16\n" + + // Rows 4-5 + ASM_PREFETCH("[%[outptr0], #80]") + "LDP q16, q17, [%[outptr4]]\n" + "FMUL v16.4s, v16.4s, %[bv].4s\n" + "LDR q18, [%[outptr4], #32]\n" + "FMUL v17.4s, v17.4s, %[bv].4s\n" + "LDP q19, q20, [%[outptr5]]\n" + "FMUL v18.4s, v18.4s, %[bv].4s\n" + "LDR q21, [%[outptr5], #32]\n" + ASM_PREFETCH("[%[outptr1], #80]") + "FMUL v19.4s, v19.4s, %[bv].4s\n" + "LDP q0, q1, [%[inptr], #192]\n" + "FMUL v20.4s, v20.4s, %[bv].4s\n" + "LDP q2, q3, [%[inptr], #224]\n" + "FMUL v21.4s, v21.4s, %[bv].4s\n" + "LDP q4, q5, [%[inptr], #256]\n" + "FMLA v16.4s, v0.4s, %[av].4s\n" + ASM_PREFETCH("[%[outptr2], #80]") + "FMLA v17.4s, v1.4s, %[av].4s\n" + "STP q16, q17, [%[outptr4]], #32\n" + "FMLA v18.4s, v2.4s, %[av].4s\n" + "STR q18, [%[outptr4]], #16\n" + "FMLA v19.4s, v3.4s, %[av].4s\n" + ASM_PREFETCH("[%[outptr3], #80]") + "FMLA v20.4s, v4.4s, %[av].4s\n" + "STP q19, q20, [%[outptr5]], #32\n" + "FMLA v21.4s, v5.4s, %[av].4s\n" + "STR q21, 
[%[outptr5]], #16\n" + + // Rows 6-7 + ASM_PREFETCH("[%[outptr4], #80]") + "LDP q16, q17, [%[outptr6]]\n" + "FMUL v16.4s, v16.4s, %[bv].4s\n" + "LDR q18, [%[outptr6], #32]\n" + "FMUL v17.4s, v17.4s, %[bv].4s\n" + "LDP q19, q20, [%[outptr7]]\n" + "FMUL v18.4s, v18.4s, %[bv].4s\n" + "LDR q21, [%[outptr7], #32]\n" + ASM_PREFETCH("[%[outptr5], #80]") + "FMUL v19.4s, v19.4s, %[bv].4s\n" + "LDP q0, q1, [%[inptr], #288]\n" + "FMUL v20.4s, v20.4s, %[bv].4s\n" + "LDP q2, q3, [%[inptr], #320]\n" + "FMUL v21.4s, v21.4s, %[bv].4s\n" + "LDP q4, q5, [%[inptr], #352]\n" + "FMLA v16.4s, v0.4s, %[av].4s\n" + ASM_PREFETCH("[%[outptr6], #128]") + "FMLA v17.4s, v1.4s, %[av].4s\n" + "STP q16, q17, [%[outptr6]], #32\n" + "FMLA v18.4s, v2.4s, %[av].4s\n" + "STR q18, [%[outptr6]], #16\n" + "FMLA v19.4s, v3.4s, %[av].4s\n" + ASM_PREFETCH("[%[outptr7], #128]") + "FMLA v20.4s, v4.4s, %[av].4s\n" + "STP q19, q20, [%[outptr7]], #32\n" + "FMLA v21.4s, v5.4s, %[av].4s\n" + "STR q21, [%[outptr7]], #16\n" + "ADD %[inptr], %[inptr], #384\n" + : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3), + [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [outptr6] "+r" (outptr6), [outptr7] "+r" (outptr7), + [inptr] "+r" (inptr) + : [av] "w" (av), [bv] "w" (bv) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21" + ); } - inptr += 96; - } - else - { - /* Optimized routine to copy an entire block */ - __asm __volatile( - // Rows 0-1 - "LDP q16, q17, [%[outptr0]]\n" - "FMUL v16.4s, v16.4s, %[bv].4s\n" - "LDR q18, [%[outptr0], #32]\n" - "FMUL v17.4s, v17.4s, %[bv].4s\n" - "LDP q19, q20, [%[outptr1]]\n" - "FMUL v18.4s, v18.4s, %[bv].4s\n" - "LDR q21, [%[outptr1], #32]\n" ASM_PREFETCH("[%[inptr], #768]") - "FMUL v19.4s, v19.4s, %[bv].4s\n" - "LDP q0, q1, [%[inptr]]\n" - "FMUL v20.4s, v20.4s, %[bv].4s\n" - "LDP q2, q3, [%[inptr], #32]\n" - "FMUL v21.4s, v21.4s, %[bv].4s\n" - "LDP q4, q5, [%[inptr], #64]\n" - "FMLA v16.4s, v0.4s, %[av].4s\n" ASM_PREFETCH("[%[inptr], #832]") - "FMLA v17.4s, v1.4s, %[av].4s\n" - "STP q16, q17, [%[outptr0]], #32\n" - "FMLA v18.4s, v2.4s, %[av].4s\n" - "STR q18, [%[outptr0]], #16\n" - "FMLA v19.4s, v3.4s, %[av].4s\n" ASM_PREFETCH("[%[inptr], #896]") - "FMLA v20.4s, v4.4s, %[av].4s\n" - "STP q19, q20, [%[outptr1]], #32\n" - "FMLA v21.4s, v5.4s, %[av].4s\n" - "STR q21, [%[outptr1]], #16\n" - - // Rows 2-3 - "LDP q16, q17, [%[outptr2]]\n" - "FMUL v16.4s, v16.4s, %[bv].4s\n" - "LDR q18, [%[outptr2], #32]\n" - "FMUL v17.4s, v17.4s, %[bv].4s\n" - "LDP q19, q20, [%[outptr3]]\n" - "FMUL v18.4s, v18.4s, %[bv].4s\n" - "LDR q21, [%[outptr3], #32]\n" ASM_PREFETCH("[%[inptr], #960]") - "FMUL v19.4s, v19.4s, %[bv].4s\n" - "LDP q0, q1, [%[inptr], #96]\n" - "FMUL v20.4s, v20.4s, %[bv].4s\n" - "LDP q2, q3, [%[inptr], #128]\n" - "FMUL v21.4s, v21.4s, %[bv].4s\n" - "LDP q4, q5, [%[inptr], #160]\n" - "FMLA v16.4s, v0.4s, %[av].4s\n" ASM_PREFETCH("[%[inptr], #1024]") - "FMLA v17.4s, v1.4s, %[av].4s\n" - "STP q16, q17, [%[outptr2]], #32\n" - "FMLA v18.4s, v2.4s, %[av].4s\n" - "STR q18, [%[outptr2]], #16\n" - "FMLA v19.4s, v3.4s, %[av].4s\n" ASM_PREFETCH("[%[inptr], #1088]") - "FMLA v20.4s, v4.4s, %[av].4s\n" - "STP q19, q20, [%[outptr3]], #32\n" - "FMLA v21.4s, v5.4s, %[av].4s\n" - "STR q21, [%[outptr3]], #16\n" - - // Rows 4-5 - ASM_PREFETCH("[%[outptr0], #80]") - "LDP q16, q17, [%[outptr4]]\n" - "FMUL v16.4s, v16.4s, %[bv].4s\n" - "LDR q18, [%[outptr4], #32]\n" - "FMUL v17.4s, v17.4s, %[bv].4s\n" - "LDP q19, q20, [%[outptr5]]\n" - "FMUL 
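// Per-4-lane form of the merge implemented by the two asm paths above
// (illustrative sketch only, assuming <arm_neon.h>; av/bv are the alpha and
// beta vectors set up at the top of the function):
static inline void merge4(float *out, const float *in,
                          float32x4_t av, float32x4_t bv, bool beta_is_zero) {
    float32x4_t res;
    if (beta_is_zero) {
        // beta == 0: never touch the old output, just scale the block (FMUL).
        res = vmulq_f32(vld1q_f32(in), av);
    } else {
        // General case: out = beta*out + alpha*in (FMUL followed by FMLA).
        res = vmulq_f32(vld1q_f32(out), bv);
        res = vfmaq_f32(res, vld1q_f32(in), av);
    }
    vst1q_f32(out, res);
}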
v18.4s, v18.4s, %[bv].4s\n" - "LDR q21, [%[outptr5], #32]\n" ASM_PREFETCH("[%[outptr1], #80]") - "FMUL v19.4s, v19.4s, %[bv].4s\n" - "LDP q0, q1, [%[inptr], #192]\n" - "FMUL v20.4s, v20.4s, %[bv].4s\n" - "LDP q2, q3, [%[inptr], #224]\n" - "FMUL v21.4s, v21.4s, %[bv].4s\n" - "LDP q4, q5, [%[inptr], #256]\n" - "FMLA v16.4s, v0.4s, %[av].4s\n" ASM_PREFETCH("[%[outptr2], #80]") - "FMLA v17.4s, v1.4s, %[av].4s\n" - "STP q16, q17, [%[outptr4]], #32\n" - "FMLA v18.4s, v2.4s, %[av].4s\n" - "STR q18, [%[outptr4]], #16\n" - "FMLA v19.4s, v3.4s, %[av].4s\n" ASM_PREFETCH("[%[outptr3], #80]") - "FMLA v20.4s, v4.4s, %[av].4s\n" - "STP q19, q20, [%[outptr5]], #32\n" - "FMLA v21.4s, v5.4s, %[av].4s\n" - "STR q21, [%[outptr5]], #16\n" - - // Rows 6-7 - ASM_PREFETCH("[%[outptr4], #80]") - "LDP q16, q17, [%[outptr6]]\n" - "FMUL v16.4s, v16.4s, %[bv].4s\n" - "LDR q18, [%[outptr6], #32]\n" - "FMUL v17.4s, v17.4s, %[bv].4s\n" - "LDP q19, q20, [%[outptr7]]\n" - "FMUL v18.4s, v18.4s, %[bv].4s\n" - "LDR q21, [%[outptr7], #32]\n" ASM_PREFETCH("[%[outptr5], #80]") - "FMUL v19.4s, v19.4s, %[bv].4s\n" - "LDP q0, q1, [%[inptr], #288]\n" - "FMUL v20.4s, v20.4s, %[bv].4s\n" - "LDP q2, q3, [%[inptr], #320]\n" - "FMUL v21.4s, v21.4s, %[bv].4s\n" - "LDP q4, q5, [%[inptr], #352]\n" - "FMLA v16.4s, v0.4s, %[av].4s\n" ASM_PREFETCH("[%[outptr6], #128]") - "FMLA v17.4s, v1.4s, %[av].4s\n" - "STP q16, q17, [%[outptr6]], #32\n" - "FMLA v18.4s, v2.4s, %[av].4s\n" - "STR q18, [%[outptr6]], #16\n" - "FMLA v19.4s, v3.4s, %[av].4s\n" ASM_PREFETCH("[%[outptr7], #128]") - "FMLA v20.4s, v4.4s, %[av].4s\n" - "STP q19, q20, [%[outptr7]], #32\n" - "FMLA v21.4s, v5.4s, %[av].4s\n" - "STR q21, [%[outptr7]], #16\n" - "ADD %[inptr], %[inptr], #384\n" - : [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3), - [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), [outptr6] "+r"(outptr6), [outptr7] "+r"(outptr7), - [inptr] "+r"(inptr) - : [av] "w"(av), [bv] "w"(bv) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21"); } } } } -#endif // __aarch64__ \ No newline at end of file +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp index 9708fe189d..e7a7521823 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp @@ -28,9 +28,8 @@ #include -template <> -inline void MergeResults<12, 8>(__fp16 *out, const float *in, int ldout, int y0, int ymax, int x0, int xmax, const __fp16 alpha, const __fp16 beta) -{ +template<> +inline void MergeResults<12,8>(__fp16 *out, const float *in, int ldout, int y0, int ymax, int x0, int xmax, const __fp16 alpha, const __fp16 beta) { const float *inptr = in; prefetch_6x(inptr); prefetch_6x(inptr + 24); @@ -38,8 +37,7 @@ inline void MergeResults<12, 8>(__fp16 *out, const float *in, int ldout, int y0, float32x4_t av = vdupq_n_f32(alpha); float32x4_t bv = vdupq_n_f32(beta); - for(int y = y0; y < ymax; y += 8) - { + for (int y=y0; y(__fp16 *out, const float *in, int ldout, int y0, prefetch_2x(outptr6); prefetch_2x(outptr7); - for(int i = x0; i < xmax; i += 12) - { + for (int i=x0; i= ymax) - { - switch((y + 7) - ymax) - { + if ((y+7) >= ymax) { + switch ((y + 7) - ymax) { case 6: outptr1 = dummyres; case 5: @@ -90,182 +85,335 @@ inline void MergeResults<12, 8>(__fp16 *out, const float *in, int ldout, int y0, } } - /* For 
ragged X, manually copy over the valid results. */ - if((i + 11) >= xmax) - { - for(int xi = 0; xi < 12; xi++) - { - if((i + xi) < xmax) - { - *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta); - outptr0++; - *outptr1 = (alpha * inptr[xi + 12]) + (*outptr1 * beta); - outptr1++; - *outptr2 = (alpha * inptr[xi + 24]) + (*outptr2 * beta); - outptr2++; - *outptr3 = (alpha * inptr[xi + 36]) + (*outptr3 * beta); - outptr3++; - *outptr4 = (alpha * inptr[xi + 48]) + (*outptr4 * beta); - outptr4++; - *outptr5 = (alpha * inptr[xi + 60]) + (*outptr5 * beta); - outptr5++; - *outptr6 = (alpha * inptr[xi + 72]) + (*outptr6 * beta); - outptr6++; - *outptr7 = (alpha * inptr[xi + 84]) + (*outptr7 * beta); - outptr7++; + if (beta == ((__fp16)0.0f)) { + /* If beta==0, don't read the output. */ + /* For ragged X, manually copy over the valid results. */ + if ((i+11) >= xmax) { + for (int xi=0; xi<12; xi++) { + if ((i+xi) < xmax) { + *outptr0 = (alpha * inptr[xi]); + outptr0++; + *outptr1 = (alpha * inptr[xi + 12]); + outptr1++; + *outptr2 = (alpha * inptr[xi + 24]); + outptr2++; + *outptr3 = (alpha * inptr[xi + 36]); + outptr3++; + *outptr4 = (alpha * inptr[xi + 48]); + outptr4++; + *outptr5 = (alpha * inptr[xi + 60]); + outptr5++; + *outptr6 = (alpha * inptr[xi + 72]); + outptr6++; + *outptr7 = (alpha * inptr[xi + 84]); + outptr7++; + } } + inptr += 96; + } else { + /* Optimized routine to copy an entire block */ + __asm __volatile ( + // Rows 0-1 + "LDP q0, q1, [%[inptr]]\n" + "LDP q2, q3, [%[inptr], #32]\n" + "LDP q4, q5, [%[inptr], #64]\n" + "FMUL v16.4s, v0.4s, %[av].4s\n" + ASM_PREFETCH("[%[inptr], #768]") + "FMUL v17.4s, v1.4s, %[av].4s\n" + ASM_PREFETCH("[%[inptr], #832]") + "FCVTN v16.4h, v16.4s\n" + ASM_PREFETCH("[%[inptr], #896]") + "FCVTN2 v16.8h, v17.4s\n" + ASM_PREFETCH("[%[inptr], #960]") + "FMUL v18.4s, v2.4s, %[av].4s\n" + "STR q16, [%[outptr0]], #16\n" + "FCVTN v18.4h, v18.4s\n" + "STR d18, [%[outptr0]], #8\n" + "FMUL v19.4s, v3.4s, %[av].4s\n" + "FMUL v20.4s, v4.4s, %[av].4s\n" + "FCVTN v19.4h, v19.4s\n" + "FCVTN2 v19.8h, v20.4s\n" + "STR q19, [%[outptr1]], #16\n" + "FMUL v21.4s, v5.4s, %[av].4s\n" + "FCVTN v21.4h, v21.4s\n" + "STR d21, [%[outptr1]], #8\n" + + // Rows 2-3 + "LDP q0, q1, [%[inptr], #96]\n" + "LDP q2, q3, [%[inptr], #128]\n" + "LDP q4, q5, [%[inptr], #160]\n" + "FMUL v16.4s, v0.4s, %[av].4s\n" + ASM_PREFETCH("[%[inptr], #1024]") + "FMUL v17.4s, v1.4s, %[av].4s\n" + ASM_PREFETCH("[%[inptr], #1088]") + "FCVTN v16.4h, v16.4s\n" + ASM_PREFETCH("[%[outptr0], #64]") + "FCVTN2 v16.8h, v17.4s\n" + ASM_PREFETCH("[%[outptr1], #64]") + "FMUL v18.4s, v2.4s, %[av].4s\n" + "STR q16, [%[outptr2]], #16\n" + "FCVTN v18.4h, v18.4s\n" + "STR d18, [%[outptr2]], #8\n" + "FMUL v19.4s, v3.4s, %[av].4s\n" + "FMUL v20.4s, v4.4s, %[av].4s\n" + "FCVTN v19.4h, v19.4s\n" + "FCVTN2 v19.8h, v20.4s\n" + "STR q19, [%[outptr3]], #16\n" + "FMUL v21.4s, v5.4s, %[av].4s\n" + "FCVTN v21.4h, v21.4s\n" + "STR d21, [%[outptr3]], #8\n" + + // Rows 4-5 + "LDP q0, q1, [%[inptr], #192]\n" + "LDP q2, q3, [%[inptr], #224]\n" + "LDP q4, q5, [%[inptr], #256]\n" + "FMUL v16.4s, v0.4s, %[av].4s\n" + "FMUL v17.4s, v1.4s, %[av].4s\n" + ASM_PREFETCH("[%[outptr2], #64]") + "FCVTN v16.4h, v16.4s\n" + ASM_PREFETCH("[%[outptr3], #64]") + "FCVTN2 v16.8h, v17.4s\n" + ASM_PREFETCH("[%[outptr4], #88]") + "FMUL v18.4s, v2.4s, %[av].4s\n" + "STR q16, [%[outptr4]], #16\n" + "FCVTN v18.4h, v18.4s\n" + "STR d18, [%[outptr4]], #8\n" + "FMUL v19.4s, v3.4s, %[av].4s\n" + "FMUL v20.4s, v4.4s, %[av].4s\n" + "FCVTN v19.4h, v19.4s\n" + "FCVTN2 
v19.8h, v20.4s\n" + "STR q19, [%[outptr5]], #16\n" + "FMUL v21.4s, v5.4s, %[av].4s\n" + "FCVTN v21.4h, v21.4s\n" + "STR d21, [%[outptr5]], #8\n" + + // Rows 6-7 + "LDP q0, q1, [%[inptr], #288]\n" + "LDP q2, q3, [%[inptr], #320]\n" + "LDP q4, q5, [%[inptr], #352]\n" + "FMUL v16.4s, v0.4s, %[av].4s\n" + "FMUL v17.4s, v1.4s, %[av].4s\n" + ASM_PREFETCH("[%[outptr5], #64]") + "FCVTN v16.4h, v16.4s\n" + ASM_PREFETCH("[%[outptr6], #88]") + "FCVTN2 v16.8h, v17.4s\n" + ASM_PREFETCH("[%[outptr7], #88]") + "FMUL v18.4s, v2.4s, %[av].4s\n" + "STR q16, [%[outptr6]], #16\n" + "FCVTN v18.4h, v18.4s\n" + "STR d18, [%[outptr6]], #8\n" + "FMUL v19.4s, v3.4s, %[av].4s\n" + "FMUL v20.4s, v4.4s, %[av].4s\n" + "FCVTN v19.4h, v19.4s\n" + "FCVTN2 v19.8h, v20.4s\n" + "STR q19, [%[outptr7]], #16\n" + "FMUL v21.4s, v5.4s, %[av].4s\n" + "FCVTN v21.4h, v21.4s\n" + "STR d21, [%[outptr7]], #8\n" + "ADD %[inptr], %[inptr], #384\n" + : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3), + [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [outptr6] "+r" (outptr6), [outptr7] "+r" (outptr7), + [inptr] "+r" (inptr) + : [av] "w" (av), [bv] "w" (bv) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21" + ); } - inptr += 96; - } - else - { - /* Optimized routine to copy an entire block */ - __asm __volatile( - // Rows 0-1 - "LDR q16, [%[outptr0]]\n" - "FCVTL2 v17.4s, v16.8h\n" - "LDR d18, [%[outptr0], #16]\n" - "FCVTL v16.4s, v16.4h\n" - "LDR q19, [%[outptr1]]\n" - "FMUL v17.4s, v17.4s, %[bv].4s\n" - "LDR d21, [%[outptr1], #16]\n" - "FMUL v16.4s, v16.4s, %[bv].4s\n" - "LDP q0, q1, [%[inptr]]\n" - "FCVTL v18.4s, v18.4h\n" - "LDP q2, q3, [%[inptr], #32]\n" - "FCVTL2 v20.4s, v19.8h\n" - "LDP q4, q5, [%[inptr], #64]\n" - "FCVTL v19.4s, v19.4h\n" ASM_PREFETCH("[%[inptr], #768]") "FCVTL v21.4s, v21.4h\n" ASM_PREFETCH("[%[inptr], #832]") "FMUL v18.4s, v18.4s, %[bv].4s\n" ASM_PREFETCH("[%[inptr], #896]") - "FMUL v20.4s, v20.4s, %[bv].4s\n" ASM_PREFETCH("[%[inptr], #960]") - "FMUL v19.4s, v19.4s, %[bv].4s\n" - "FMUL v21.4s, v21.4s, %[bv].4s\n" - "FMLA v16.4s, v0.4s, %[av].4s\n" - "FMLA v17.4s, v1.4s, %[av].4s\n" - "FCVTN v16.4h, v16.4s\n" - "FCVTN2 v16.8h, v17.4s\n" - "FMLA v18.4s, v2.4s, %[av].4s\n" - "STR q16, [%[outptr0]], #16\n" - "FCVTN v18.4h, v18.4s\n" - "STR d18, [%[outptr0]], #8\n" - "FMLA v19.4s, v3.4s, %[av].4s\n" - "FMLA v20.4s, v4.4s, %[av].4s\n" - "FCVTN v19.4h, v19.4s\n" - "FCVTN2 v19.8h, v20.4s\n" - "STR q19, [%[outptr1]], #16\n" - "FMLA v21.4s, v5.4s, %[av].4s\n" - "FCVTN v21.4h, v21.4s\n" - "STR d21, [%[outptr1]], #8\n" + } else { + /* For ragged X, manually copy over the valid results. 
*/ + if ((i+11) >= xmax) { + for (int xi=0; xi<12; xi++) { + if ((i+xi) < xmax) { + *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta); + outptr0++; + *outptr1 = (alpha * inptr[xi + 12]) + (*outptr1 * beta); + outptr1++; + *outptr2 = (alpha * inptr[xi + 24]) + (*outptr2 * beta); + outptr2++; + *outptr3 = (alpha * inptr[xi + 36]) + (*outptr3 * beta); + outptr3++; + *outptr4 = (alpha * inptr[xi + 48]) + (*outptr4 * beta); + outptr4++; + *outptr5 = (alpha * inptr[xi + 60]) + (*outptr5 * beta); + outptr5++; + *outptr6 = (alpha * inptr[xi + 72]) + (*outptr6 * beta); + outptr6++; + *outptr7 = (alpha * inptr[xi + 84]) + (*outptr7 * beta); + outptr7++; + } + } + inptr += 96; + } else { + /* Optimized routine to copy an entire block */ + __asm __volatile ( + // Rows 0-1 + "LDR q16, [%[outptr0]]\n" + "FCVTL2 v17.4s, v16.8h\n" + "LDR d18, [%[outptr0], #16]\n" + "FCVTL v16.4s, v16.4h\n" + "LDR q19, [%[outptr1]]\n" + "FMUL v17.4s, v17.4s, %[bv].4s\n" + "LDR d21, [%[outptr1], #16]\n" + "FMUL v16.4s, v16.4s, %[bv].4s\n" + "LDP q0, q1, [%[inptr]]\n" + "FCVTL v18.4s, v18.4h\n" + "LDP q2, q3, [%[inptr], #32]\n" + "FCVTL2 v20.4s, v19.8h\n" + "LDP q4, q5, [%[inptr], #64]\n" + "FCVTL v19.4s, v19.4h\n" + ASM_PREFETCH("[%[inptr], #768]") + "FCVTL v21.4s, v21.4h\n" + ASM_PREFETCH("[%[inptr], #832]") + "FMUL v18.4s, v18.4s, %[bv].4s\n" + ASM_PREFETCH("[%[inptr], #896]") + "FMUL v20.4s, v20.4s, %[bv].4s\n" + ASM_PREFETCH("[%[inptr], #960]") + "FMUL v19.4s, v19.4s, %[bv].4s\n" + "FMUL v21.4s, v21.4s, %[bv].4s\n" + "FMLA v16.4s, v0.4s, %[av].4s\n" + "FMLA v17.4s, v1.4s, %[av].4s\n" + "FCVTN v16.4h, v16.4s\n" + "FCVTN2 v16.8h, v17.4s\n" + "FMLA v18.4s, v2.4s, %[av].4s\n" + "STR q16, [%[outptr0]], #16\n" + "FCVTN v18.4h, v18.4s\n" + "STR d18, [%[outptr0]], #8\n" + "FMLA v19.4s, v3.4s, %[av].4s\n" + "FMLA v20.4s, v4.4s, %[av].4s\n" + "FCVTN v19.4h, v19.4s\n" + "FCVTN2 v19.8h, v20.4s\n" + "STR q19, [%[outptr1]], #16\n" + "FMLA v21.4s, v5.4s, %[av].4s\n" + "FCVTN v21.4h, v21.4s\n" + "STR d21, [%[outptr1]], #8\n" - // Rows 2-3 - "LDR q16, [%[outptr2]]\n" - "FCVTL2 v17.4s, v16.8h\n" - "LDR d18, [%[outptr2], #16]\n" - "FCVTL v16.4s, v16.4h\n" - "LDR q19, [%[outptr3]]\n" - "FMUL v17.4s, v17.4s, %[bv].4s\n" - "LDR d21, [%[outptr3], #16]\n" - "FMUL v16.4s, v16.4s, %[bv].4s\n" - "LDP q0, q1, [%[inptr], #96]\n" - "FCVTL v18.4s, v18.4h\n" - "LDP q2, q3, [%[inptr], #128]\n" - "FCVTL2 v20.4s, v19.8h\n" - "LDP q4, q5, [%[inptr], #160]\n" - "FCVTL v19.4s, v19.4h\n" ASM_PREFETCH("[%[inptr], #1024]") "FCVTL v21.4s, v21.4h\n" ASM_PREFETCH("[%[inptr], #1088]") "FMUL v18.4s, v18.4s, %[bv].4s\n" ASM_PREFETCH("[%[outptr0], #64]") - "FMUL v20.4s, v20.4s, %[bv].4s\n" ASM_PREFETCH("[%[outptr1], #64]") - "FMUL v19.4s, v19.4s, %[bv].4s\n" - "FMUL v21.4s, v21.4s, %[bv].4s\n" - "FMLA v16.4s, v0.4s, %[av].4s\n" - "FMLA v17.4s, v1.4s, %[av].4s\n" - "FCVTN v16.4h, v16.4s\n" - "FCVTN2 v16.8h, v17.4s\n" - "FMLA v18.4s, v2.4s, %[av].4s\n" - "STR q16, [%[outptr2]], #16\n" - "FCVTN v18.4h, v18.4s\n" - "STR d18, [%[outptr2]], #8\n" - "FMLA v19.4s, v3.4s, %[av].4s\n" - "FMLA v20.4s, v4.4s, %[av].4s\n" - "FCVTN v19.4h, v19.4s\n" - "FCVTN2 v19.8h, v20.4s\n" - "STR q19, [%[outptr3]], #16\n" - "FMLA v21.4s, v5.4s, %[av].4s\n" - "FCVTN v21.4h, v21.4s\n" - "STR d21, [%[outptr3]], #8\n" + // Rows 2-3 + "LDR q16, [%[outptr2]]\n" + "FCVTL2 v17.4s, v16.8h\n" + "LDR d18, [%[outptr2], #16]\n" + "FCVTL v16.4s, v16.4h\n" + "LDR q19, [%[outptr3]]\n" + "FMUL v17.4s, v17.4s, %[bv].4s\n" + "LDR d21, [%[outptr3], #16]\n" + "FMUL v16.4s, v16.4s, %[bv].4s\n" + "LDP q0, q1, 
[%[inptr], #96]\n" + "FCVTL v18.4s, v18.4h\n" + "LDP q2, q3, [%[inptr], #128]\n" + "FCVTL2 v20.4s, v19.8h\n" + "LDP q4, q5, [%[inptr], #160]\n" + "FCVTL v19.4s, v19.4h\n" + ASM_PREFETCH("[%[inptr], #1024]") + "FCVTL v21.4s, v21.4h\n" + ASM_PREFETCH("[%[inptr], #1088]") + "FMUL v18.4s, v18.4s, %[bv].4s\n" + ASM_PREFETCH("[%[outptr0], #64]") + "FMUL v20.4s, v20.4s, %[bv].4s\n" + ASM_PREFETCH("[%[outptr1], #64]") + "FMUL v19.4s, v19.4s, %[bv].4s\n" + "FMUL v21.4s, v21.4s, %[bv].4s\n" + "FMLA v16.4s, v0.4s, %[av].4s\n" + "FMLA v17.4s, v1.4s, %[av].4s\n" + "FCVTN v16.4h, v16.4s\n" + "FCVTN2 v16.8h, v17.4s\n" + "FMLA v18.4s, v2.4s, %[av].4s\n" + "STR q16, [%[outptr2]], #16\n" + "FCVTN v18.4h, v18.4s\n" + "STR d18, [%[outptr2]], #8\n" + "FMLA v19.4s, v3.4s, %[av].4s\n" + "FMLA v20.4s, v4.4s, %[av].4s\n" + "FCVTN v19.4h, v19.4s\n" + "FCVTN2 v19.8h, v20.4s\n" + "STR q19, [%[outptr3]], #16\n" + "FMLA v21.4s, v5.4s, %[av].4s\n" + "FCVTN v21.4h, v21.4s\n" + "STR d21, [%[outptr3]], #8\n" - // Rows 4-5 - "LDR q16, [%[outptr4]]\n" - "FCVTL2 v17.4s, v16.8h\n" - "LDR d18, [%[outptr4], #16]\n" - "FCVTL v16.4s, v16.4h\n" - "LDR q19, [%[outptr5]]\n" - "FMUL v17.4s, v17.4s, %[bv].4s\n" - "LDR d21, [%[outptr5], #16]\n" - "FMUL v16.4s, v16.4s, %[bv].4s\n" - "LDP q0, q1, [%[inptr], #192]\n" - "FCVTL v18.4s, v18.4h\n" - "LDP q2, q3, [%[inptr], #224]\n" - "FCVTL2 v20.4s, v19.8h\n" - "LDP q4, q5, [%[inptr], #256]\n" - "FCVTL v19.4s, v19.4h\n" ASM_PREFETCH("[%[outptr2], #64]") "FCVTL v21.4s, v21.4h\n" ASM_PREFETCH("[%[outptr3], #64]") "FMUL v18.4s, v18.4s, %[bv].4s\n" ASM_PREFETCH("[%[outptr4], #88]") - "FMUL v20.4s, v20.4s, %[bv].4s\n" - "FMUL v19.4s, v19.4s, %[bv].4s\n" - "FMUL v21.4s, v21.4s, %[bv].4s\n" - "FMLA v16.4s, v0.4s, %[av].4s\n" - "FMLA v17.4s, v1.4s, %[av].4s\n" - "FCVTN v16.4h, v16.4s\n" - "FCVTN2 v16.8h, v17.4s\n" - "FMLA v18.4s, v2.4s, %[av].4s\n" - "STR q16, [%[outptr4]], #16\n" - "FCVTN v18.4h, v18.4s\n" - "STR d18, [%[outptr4]], #8\n" - "FMLA v19.4s, v3.4s, %[av].4s\n" - "FMLA v20.4s, v4.4s, %[av].4s\n" - "FCVTN v19.4h, v19.4s\n" - "FCVTN2 v19.8h, v20.4s\n" - "STR q19, [%[outptr5]], #16\n" - "FMLA v21.4s, v5.4s, %[av].4s\n" - "FCVTN v21.4h, v21.4s\n" - "STR d21, [%[outptr5]], #8\n" + // Rows 4-5 + "LDR q16, [%[outptr4]]\n" + "FCVTL2 v17.4s, v16.8h\n" + "LDR d18, [%[outptr4], #16]\n" + "FCVTL v16.4s, v16.4h\n" + "LDR q19, [%[outptr5]]\n" + "FMUL v17.4s, v17.4s, %[bv].4s\n" + "LDR d21, [%[outptr5], #16]\n" + "FMUL v16.4s, v16.4s, %[bv].4s\n" + "LDP q0, q1, [%[inptr], #192]\n" + "FCVTL v18.4s, v18.4h\n" + "LDP q2, q3, [%[inptr], #224]\n" + "FCVTL2 v20.4s, v19.8h\n" + "LDP q4, q5, [%[inptr], #256]\n" + "FCVTL v19.4s, v19.4h\n" + ASM_PREFETCH("[%[outptr2], #64]") + "FCVTL v21.4s, v21.4h\n" + ASM_PREFETCH("[%[outptr3], #64]") + "FMUL v18.4s, v18.4s, %[bv].4s\n" + ASM_PREFETCH("[%[outptr4], #88]") + "FMUL v20.4s, v20.4s, %[bv].4s\n" + "FMUL v19.4s, v19.4s, %[bv].4s\n" + "FMUL v21.4s, v21.4s, %[bv].4s\n" + "FMLA v16.4s, v0.4s, %[av].4s\n" + "FMLA v17.4s, v1.4s, %[av].4s\n" + "FCVTN v16.4h, v16.4s\n" + "FCVTN2 v16.8h, v17.4s\n" + "FMLA v18.4s, v2.4s, %[av].4s\n" + "STR q16, [%[outptr4]], #16\n" + "FCVTN v18.4h, v18.4s\n" + "STR d18, [%[outptr4]], #8\n" + "FMLA v19.4s, v3.4s, %[av].4s\n" + "FMLA v20.4s, v4.4s, %[av].4s\n" + "FCVTN v19.4h, v19.4s\n" + "FCVTN2 v19.8h, v20.4s\n" + "STR q19, [%[outptr5]], #16\n" + "FMLA v21.4s, v5.4s, %[av].4s\n" + "FCVTN v21.4h, v21.4s\n" + "STR d21, [%[outptr5]], #8\n" - // Rows 6-7 - "LDR q16, [%[outptr6]]\n" - "FCVTL2 v17.4s, v16.8h\n" - "LDR d18, [%[outptr6], #16]\n" - 
"FCVTL v16.4s, v16.4h\n" - "LDR q19, [%[outptr7]]\n" - "FMUL v17.4s, v17.4s, %[bv].4s\n" - "LDR d21, [%[outptr7], #16]\n" - "FMUL v16.4s, v16.4s, %[bv].4s\n" - "LDP q0, q1, [%[inptr], #288]\n" - "FCVTL v18.4s, v18.4h\n" - "LDP q2, q3, [%[inptr], #320]\n" - "FCVTL2 v20.4s, v19.8h\n" - "LDP q4, q5, [%[inptr], #352]\n" - "FCVTL v19.4s, v19.4h\n" ASM_PREFETCH("[%[outptr5], #64]") "FCVTL v21.4s, v21.4h\n" ASM_PREFETCH("[%[outptr6], #88]") "FMUL v18.4s, v18.4s, %[bv].4s\n" ASM_PREFETCH("[%[outptr7], #88]") - "FMUL v20.4s, v20.4s, %[bv].4s\n" - "FMUL v19.4s, v19.4s, %[bv].4s\n" - "FMUL v21.4s, v21.4s, %[bv].4s\n" - "FMLA v16.4s, v0.4s, %[av].4s\n" - "FMLA v17.4s, v1.4s, %[av].4s\n" - "FCVTN v16.4h, v16.4s\n" - "FCVTN2 v16.8h, v17.4s\n" - "FMLA v18.4s, v2.4s, %[av].4s\n" - "STR q16, [%[outptr6]], #16\n" - "FCVTN v18.4h, v18.4s\n" - "STR d18, [%[outptr6]], #8\n" - "FMLA v19.4s, v3.4s, %[av].4s\n" - "FMLA v20.4s, v4.4s, %[av].4s\n" - "FCVTN v19.4h, v19.4s\n" - "FCVTN2 v19.8h, v20.4s\n" - "STR q19, [%[outptr7]], #16\n" - "FMLA v21.4s, v5.4s, %[av].4s\n" - "FCVTN v21.4h, v21.4s\n" - "STR d21, [%[outptr7]], #8\n" - "ADD %[inptr], %[inptr], #384\n" - : [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3), - [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), [outptr6] "+r"(outptr6), [outptr7] "+r"(outptr7), - [inptr] "+r"(inptr) - : [av] "w"(av), [bv] "w"(bv) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21"); + // Rows 6-7 + "LDR q16, [%[outptr6]]\n" + "FCVTL2 v17.4s, v16.8h\n" + "LDR d18, [%[outptr6], #16]\n" + "FCVTL v16.4s, v16.4h\n" + "LDR q19, [%[outptr7]]\n" + "FMUL v17.4s, v17.4s, %[bv].4s\n" + "LDR d21, [%[outptr7], #16]\n" + "FMUL v16.4s, v16.4s, %[bv].4s\n" + "LDP q0, q1, [%[inptr], #288]\n" + "FCVTL v18.4s, v18.4h\n" + "LDP q2, q3, [%[inptr], #320]\n" + "FCVTL2 v20.4s, v19.8h\n" + "LDP q4, q5, [%[inptr], #352]\n" + "FCVTL v19.4s, v19.4h\n" + ASM_PREFETCH("[%[outptr5], #64]") + "FCVTL v21.4s, v21.4h\n" + ASM_PREFETCH("[%[outptr6], #88]") + "FMUL v18.4s, v18.4s, %[bv].4s\n" + ASM_PREFETCH("[%[outptr7], #88]") + "FMUL v20.4s, v20.4s, %[bv].4s\n" + "FMUL v19.4s, v19.4s, %[bv].4s\n" + "FMUL v21.4s, v21.4s, %[bv].4s\n" + "FMLA v16.4s, v0.4s, %[av].4s\n" + "FMLA v17.4s, v1.4s, %[av].4s\n" + "FCVTN v16.4h, v16.4s\n" + "FCVTN2 v16.8h, v17.4s\n" + "FMLA v18.4s, v2.4s, %[av].4s\n" + "STR q16, [%[outptr6]], #16\n" + "FCVTN v18.4h, v18.4s\n" + "STR d18, [%[outptr6]], #8\n" + "FMLA v19.4s, v3.4s, %[av].4s\n" + "FMLA v20.4s, v4.4s, %[av].4s\n" + "FCVTN v19.4h, v19.4s\n" + "FCVTN2 v19.8h, v20.4s\n" + "STR q19, [%[outptr7]], #16\n" + "FMLA v21.4s, v5.4s, %[av].4s\n" + "FCVTN v21.4h, v21.4s\n" + "STR d21, [%[outptr7]], #8\n" + "ADD %[inptr], %[inptr], #384\n" + : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3), + [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [outptr6] "+r" (outptr6), [outptr7] "+r" (outptr7), + [inptr] "+r" (inptr) + : [av] "w" (av), [bv] "w" (bv) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21" + ); + } } } } diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp index 08cfc00523..3ed43b10bd 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp @@ -23,12 +23,12 @@ */ #pragma once -#if defined(__aarch64__) && 
defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) +// AArch64 only, and either the FP16_KERNELS option set or the target explicitly supports FP16 vectors. +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) -template <> +template<> inline void MergeResults<24, 8>(__fp16 *out, const __fp16 *in, const int ldout, const int y0, const int ymax, - const int x0, const int xmax, const __fp16 alpha, const __fp16 beta) -{ + const int x0, const int xmax, const __fp16 alpha, const __fp16 beta) { const __fp16 *inptr = in; prefetch_6x(inptr); prefetch_6x(inptr + 48); @@ -36,8 +36,7 @@ inline void MergeResults<24, 8>(__fp16 *out, const __fp16 *in, const int ldout, float16x8_t va = vdupq_n_f16(alpha); float16x8_t vb = vdupq_n_f16(beta); - for(int y = y0; y < ymax; y += 8) - { + for (int y=y0; y(__fp16 *out, const __fp16 *in, const int ldout, prefetch_2x(outptr6); prefetch_2x(outptr7); - for(int i = x0; i < xmax; i += 24) - { + for (int i=x0; i= ymax) - { - switch((y + 7) - ymax) - { + if ((y+7) >= ymax) { + switch ((y + 7) - ymax) { case 6: outptr1 = dummyres; case 5: @@ -85,149 +81,277 @@ inline void MergeResults<24, 8>(__fp16 *out, const __fp16 *in, const int ldout, default: UNREACHABLE("Impossible."); + } } - /* For ragged X, manually copy over the valid results. */ - if((i + 23) >= xmax) - { - for(int xi = 0; xi < 24; xi++) - { - if((i + xi) < xmax) - { - *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta); - outptr0++; - *outptr1 = (alpha * inptr[xi + 24]) + (*outptr1 * beta); - outptr1++; - *outptr2 = (alpha * inptr[xi + 48]) + (*outptr2 * beta); - outptr2++; - *outptr3 = (alpha * inptr[xi + 72]) + (*outptr3 * beta); - outptr3++; - *outptr4 = (alpha * inptr[xi + 96]) + (*outptr4 * beta); - outptr4++; - *outptr5 = (alpha * inptr[xi + 120]) + (*outptr5 * beta); - outptr5++; - *outptr6 = (alpha * inptr[xi + 144]) + (*outptr6 * beta); - outptr6++; - *outptr7 = (alpha * inptr[xi + 168]) + (*outptr7 * beta); - outptr7++; + if (beta == (__fp16)0.0f) { + /* If beta===0, don't read the output. */ + + /* For ragged X, manually copy over the valid results. 
*/ + if ((i+23) >= xmax) { + for (int xi=0; xi<24; xi++) { + if ((i+xi) < xmax) { + *outptr0 = (alpha * inptr[xi]); + outptr0++; + *outptr1 = (alpha * inptr[xi + 24]); + outptr1++; + *outptr2 = (alpha * inptr[xi + 48]); + outptr2++; + *outptr3 = (alpha * inptr[xi + 72]); + outptr3++; + *outptr4 = (alpha * inptr[xi + 96]); + outptr4++; + *outptr5 = (alpha * inptr[xi + 120]); + outptr5++; + *outptr6 = (alpha * inptr[xi + 144]); + outptr6++; + *outptr7 = (alpha * inptr[xi + 168]); + outptr7++; + } } + inptr += 192; + } else { + /* Optimized routine to copy an entire block */ + __asm __volatile ( +#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + ".arch armv8.2-a+fp16\n" +#endif + // Rows 0-1 + ASM_PREFETCH("[%[inptr], #768]") + "LDP q0, q1, [%[inptr]]\n" + "LDP q2, q3, [%[inptr], #32]\n" + "LDP q4, q5, [%[inptr], #64]\n" + "FMUL v16.8h, v0.8h, %[va].8h\n" + ASM_PREFETCH("[%[inptr], #832]") + "FMUL v17.8h, v1.8h, %[va].8h\n" + "STP q16, q17, [%[outptr0]], #32\n" + "FMUL v18.8h, v2.8h, %[va].8h\n" + "STR q18, [%[outptr0]], #16\n" + "FMUL v19.8h, v3.8h, %[va].8h\n" + ASM_PREFETCH("[%[inptr], #896]") + "FMUL v20.8h, v4.8h, %[va].8h\n" + "STP q19, q20, [%[outptr1]], #32\n" + "FMUL v21.8h, v5.8h, %[va].8h\n" + "STR q21, [%[outptr1]], #16\n" + ASM_PREFETCH("[%[inptr], #960]") + + // Rows 2-3 + ASM_PREFETCH("[%[inptr], #1024]") + "LDP q0, q1, [%[inptr], #96]\n" + "LDP q2, q3, [%[inptr], #128]\n" + "LDP q4, q5, [%[inptr], #160]\n" + "FMUL v16.8h, v0.8h, %[va].8h\n" + ASM_PREFETCH("[%[inptr], #1088]") + "FMUL v17.8h, v1.8h, %[va].8h\n" + "STP q16, q17, [%[outptr2]], #32\n" + "FMUL v18.8h, v2.8h, %[va].8h\n" + "STR q18, [%[outptr2]], #16\n" + "FMUL v19.8h, v3.8h, %[va].8h\n" + ASM_PREFETCH("[%[outptr0], #80]") + "FMUL v20.8h, v4.8h, %[va].8h\n" + "STP q19, q20, [%[outptr3]], #32\n" + "FMUL v21.8h, v5.8h, %[va].8h\n" + "STR q21, [%[outptr3]], #16\n" + ASM_PREFETCH("[%[outptr1], #80]") + + // Rows 4-5 + ASM_PREFETCH("[%[outptr2], #80]") + "LDP q0, q1, [%[inptr], #192]\n" + "LDP q2, q3, [%[inptr], #224]\n" + "LDP q4, q5, [%[inptr], #256]\n" + "FMUL v16.8h, v0.8h, %[va].8h\n" + ASM_PREFETCH("[%[outptr3], #80]") + "FMUL v17.8h, v1.8h, %[va].8h\n" + "STP q16, q17, [%[outptr4]], #32\n" + "FMUL v18.8h, v2.8h, %[va].8h\n" + "STR q18, [%[outptr4]], #16\n" + "FMUL v19.8h, v3.8h, %[va].8h\n" + ASM_PREFETCH("[%[outptr4], #80]") + "FMUL v20.8h, v4.8h, %[va].8h\n" + "STP q19, q20, [%[outptr5]], #32\n" + "FMUL v21.8h, v5.8h, %[va].8h\n" + "STR q21, [%[outptr5]], #16\n" + + // Rows 6-7 + ASM_PREFETCH("[%[outptr5], #80]") + "LDP q0, q1, [%[inptr], #288]\n" + "LDP q2, q3, [%[inptr], #320]\n" + "LDP q4, q5, [%[inptr], #352]\n" + "FMUL v16.8h, v0.8h, %[va].8h\n" + ASM_PREFETCH("[%[outptr6], #128]") + "FMUL v17.8h, v1.8h, %[va].8h\n" + "STP q16, q17, [%[outptr6]], #32\n" + "FMUL v18.8h, v2.8h, %[va].8h\n" + "STR q18, [%[outptr6]], #16\n" + "FMUL v19.8h, v3.8h, %[va].8h\n" + ASM_PREFETCH("[%[outptr7], #128]") + "FMUL v20.8h, v4.8h, %[va].8h\n" + "STP q19, q20, [%[outptr7]], #32\n" + "FMUL v21.8h, v5.8h, %[va].8h\n" + "STR q21, [%[outptr7]], #16\n" + "ADD %[inptr], %[inptr], #384\n" + : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3), + [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [outptr6] "+r" (outptr6), [outptr7] "+r" (outptr7), + [inptr] "+r" (inptr) + : [va] "w" (va), [vb] "w" (vb) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21" + ); + } + } else { + /* For ragged X, manually copy over the valid results. 
*/ + if ((i+23) >= xmax) { + for (int xi=0; xi<24; xi++) { + if ((i+xi) < xmax) { + *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta); + outptr0++; + *outptr1 = (alpha * inptr[xi + 24]) + (*outptr1 * beta); + outptr1++; + *outptr2 = (alpha * inptr[xi + 48]) + (*outptr2 * beta); + outptr2++; + *outptr3 = (alpha * inptr[xi + 72]) + (*outptr3 * beta); + outptr3++; + *outptr4 = (alpha * inptr[xi + 96]) + (*outptr4 * beta); + outptr4++; + *outptr5 = (alpha * inptr[xi + 120]) + (*outptr5 * beta); + outptr5++; + *outptr6 = (alpha * inptr[xi + 144]) + (*outptr6 * beta); + outptr6++; + *outptr7 = (alpha * inptr[xi + 168]) + (*outptr7 * beta); + outptr7++; + } + } + inptr += 192; + } else { + /* Optimized routine to copy an entire block */ + __asm __volatile ( +#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + ".arch armv8.2-a+fp16\n" +#endif + // Rows 0-1 + "LDP q16, q17, [%[outptr0]]\n" + "FMUL v16.8h, v16.8h, %[vb].8h\n" + "LDR q18, [%[outptr0], #32]\n" + "FMUL v17.8h, v17.8h, %[vb].8h\n" + "LDP q19, q20, [%[outptr1]]\n" + "FMUL v18.8h, v18.8h, %[vb].8h\n" + ASM_PREFETCH("[%[inptr], #768]") + "LDR q21, [%[outptr1], #32]\n" + "FMUL v19.8h, v19.8h, %[vb].8h\n" + "LDP q0, q1, [%[inptr]]\n" + "FMUL v20.8h, v20.8h, %[vb].8h\n" + "LDP q2, q3, [%[inptr], #32]\n" + "FMUL v21.8h, v21.8h, %[vb].8h\n" + "LDP q4, q5, [%[inptr], #64]\n" + "FMLA v16.8h, v0.8h, %[va].8h\n" + ASM_PREFETCH("[%[inptr], #832]") + "FMLA v17.8h, v1.8h, %[va].8h\n" + "STP q16, q17, [%[outptr0]], #32\n" + "FMLA v18.8h, v2.8h, %[va].8h\n" + "STR q18, [%[outptr0]], #16\n" + "FMLA v19.8h, v3.8h, %[va].8h\n" + ASM_PREFETCH("[%[inptr], #896]") + "FMLA v20.8h, v4.8h, %[va].8h\n" + "STP q19, q20, [%[outptr1]], #32\n" + "FMLA v21.8h, v5.8h, %[va].8h\n" + "STR q21, [%[outptr1]], #16\n" + ASM_PREFETCH("[%[inptr], #960]") + + // Rows 2-3 + "LDP q16, q17, [%[outptr2]]\n" + "FMUL v16.8h, v16.8h, %[vb].8h\n" + "LDR q18, [%[outptr2], #32]\n" + "FMUL v17.8h, v17.8h, %[vb].8h\n" + "LDP q19, q20, [%[outptr3]]\n" + "FMUL v18.8h, v18.8h, %[vb].8h\n" + ASM_PREFETCH("[%[inptr], #1024]") + "LDR q21, [%[outptr3], #32]\n" + "FMUL v19.8h, v19.8h, %[vb].8h\n" + "LDP q0, q1, [%[inptr], #96]\n" + "FMUL v20.8h, v20.8h, %[vb].8h\n" + "LDP q2, q3, [%[inptr], #128]\n" + "FMUL v21.8h, v21.8h, %[vb].8h\n" + "LDP q4, q5, [%[inptr], #160]\n" + "FMLA v16.8h, v0.8h, %[va].8h\n" + ASM_PREFETCH("[%[inptr], #1088]") + "FMLA v17.8h, v1.8h, %[va].8h\n" + "STP q16, q17, [%[outptr2]], #32\n" + "FMLA v18.8h, v2.8h, %[va].8h\n" + "STR q18, [%[outptr2]], #16\n" + "FMLA v19.8h, v3.8h, %[va].8h\n" + ASM_PREFETCH("[%[outptr0], #80]") + "FMLA v20.8h, v4.8h, %[va].8h\n" + "STP q19, q20, [%[outptr3]], #32\n" + "FMLA v21.8h, v5.8h, %[va].8h\n" + "STR q21, [%[outptr3]], #16\n" + ASM_PREFETCH("[%[outptr1], #80]") + + // Rows 4-5 + "LDP q16, q17, [%[outptr4]]\n" + "FMUL v16.8h, v16.8h, %[vb].8h\n" + "LDR q18, [%[outptr4], #32]\n" + "FMUL v17.8h, v17.8h, %[vb].8h\n" + "LDP q19, q20, [%[outptr5]]\n" + "FMUL v18.8h, v18.8h, %[vb].8h\n" + ASM_PREFETCH("[%[outptr2], #80]") + "LDR q21, [%[outptr5], #32]\n" + "FMUL v19.8h, v19.8h, %[vb].8h\n" + "LDP q0, q1, [%[inptr], #192]\n" + "FMUL v20.8h, v20.8h, %[vb].8h\n" + "LDP q2, q3, [%[inptr], #224]\n" + "FMUL v21.8h, v21.8h, %[vb].8h\n" + "LDP q4, q5, [%[inptr], #256]\n" + "FMLA v16.8h, v0.8h, %[va].8h\n" + ASM_PREFETCH("[%[outptr3], #80]") + "FMLA v17.8h, v1.8h, %[va].8h\n" + "STP q16, q17, [%[outptr4]], #32\n" + "FMLA v18.8h, v2.8h, %[va].8h\n" + "STR q18, [%[outptr4]], #16\n" + "FMLA v19.8h, v3.8h, %[va].8h\n" + ASM_PREFETCH("[%[outptr4], #80]") + "FMLA 
v20.8h, v4.8h, %[va].8h\n" + "STP q19, q20, [%[outptr5]], #32\n" + "FMLA v21.8h, v5.8h, %[va].8h\n" + "STR q21, [%[outptr5]], #16\n" + + // Rows 6-7 + "LDP q16, q17, [%[outptr6]]\n" + "FMUL v16.8h, v16.8h, %[vb].8h\n" + "LDR q18, [%[outptr6], #32]\n" + "FMUL v17.8h, v17.8h, %[vb].8h\n" + "LDP q19, q20, [%[outptr7]]\n" + ASM_PREFETCH("[%[outptr5], #80]") + "FMUL v18.8h, v18.8h, %[vb].8h\n" + "LDR q21, [%[outptr7], #32]\n" + "FMUL v19.8h, v19.8h, %[vb].8h\n" + "LDP q0, q1, [%[inptr], #288]\n" + "FMUL v20.8h, v20.8h, %[vb].8h\n" + "LDP q2, q3, [%[inptr], #320]\n" + "FMUL v21.8h, v21.8h, %[vb].8h\n" + "LDP q4, q5, [%[inptr], #352]\n" + "FMLA v16.8h, v0.8h, %[va].8h\n" + ASM_PREFETCH("[%[outptr6], #128]") + "FMLA v17.8h, v1.8h, %[va].8h\n" + "STP q16, q17, [%[outptr6]], #32\n" + "FMLA v18.8h, v2.8h, %[va].8h\n" + "STR q18, [%[outptr6]], #16\n" + "FMLA v19.8h, v3.8h, %[va].8h\n" + ASM_PREFETCH("[%[outptr7], #128]") + "FMLA v20.8h, v4.8h, %[va].8h\n" + "STP q19, q20, [%[outptr7]], #32\n" + "FMLA v21.8h, v5.8h, %[va].8h\n" + "STR q21, [%[outptr7]], #16\n" + "ADD %[inptr], %[inptr], #384\n" + : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3), + [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [outptr6] "+r" (outptr6), [outptr7] "+r" (outptr7), + [inptr] "+r" (inptr) + : [va] "w" (va), [vb] "w" (vb) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21" + ); } - inptr += 192; - } - else - { - /* Optimized routine to copy an entire block */ - __asm __volatile( - ".arch armv8.2-a+fp16\n" - // Rows 0-1 - "LDP q16, q17, [%[outptr0]]\n" - "FMUL v16.8h, v16.8h, %[vb].8h\n" - "LDR q18, [%[outptr0], #32]\n" - "FMUL v17.8h, v17.8h, %[vb].8h\n" - "LDP q19, q20, [%[outptr1]]\n" - "FMUL v18.8h, v18.8h, %[vb].8h\n" ASM_PREFETCH("[%[inptr], #768]") - "LDR q21, [%[outptr1], #32]\n" - "FMUL v19.8h, v19.8h, %[vb].8h\n" - "LDP q0, q1, [%[inptr]]\n" - "FMUL v20.8h, v20.8h, %[vb].8h\n" - "LDP q2, q3, [%[inptr], #32]\n" - "FMUL v21.8h, v21.8h, %[vb].8h\n" - "LDP q4, q5, [%[inptr], #64]\n" - "FMLA v16.8h, v0.8h, %[va].8h\n" ASM_PREFETCH("[%[inptr], #832]") - "FMLA v17.8h, v1.8h, %[va].8h\n" - "STP q16, q17, [%[outptr0]], #32\n" - "FMLA v18.8h, v2.8h, %[va].8h\n" - "STR q18, [%[outptr0]], #16\n" - "FMLA v19.8h, v3.8h, %[va].8h\n" ASM_PREFETCH("[%[inptr], #896]") - "FMLA v20.8h, v4.8h, %[va].8h\n" - "STP q19, q20, [%[outptr1]], #32\n" - "FMLA v21.8h, v5.8h, %[va].8h\n" - "STR q21, [%[outptr1]], #16\n" ASM_PREFETCH("[%[inptr], #960]") - - // Rows 2-3 - "LDP q16, q17, [%[outptr2]]\n" - "FMUL v16.8h, v16.8h, %[vb].8h\n" - "LDR q18, [%[outptr2], #32]\n" - "FMUL v17.8h, v17.8h, %[vb].8h\n" - "LDP q19, q20, [%[outptr3]]\n" - "FMUL v18.8h, v18.8h, %[vb].8h\n" ASM_PREFETCH("[%[inptr], #1024]") - "LDR q21, [%[outptr3], #32]\n" - "FMUL v19.8h, v19.8h, %[vb].8h\n" - "LDP q0, q1, [%[inptr], #96]\n" - "FMUL v20.8h, v20.8h, %[vb].8h\n" - "LDP q2, q3, [%[inptr], #128]\n" - "FMUL v21.8h, v21.8h, %[vb].8h\n" - "LDP q4, q5, [%[inptr], #160]\n" - "FMLA v16.8h, v0.8h, %[va].8h\n" ASM_PREFETCH("[%[inptr], #1088]") - "FMLA v17.8h, v1.8h, %[va].8h\n" - "STP q16, q17, [%[outptr2]], #32\n" - "FMLA v18.8h, v2.8h, %[va].8h\n" - "STR q18, [%[outptr2]], #16\n" - "FMLA v19.8h, v3.8h, %[va].8h\n" ASM_PREFETCH("[%[outptr0], #80]") - "FMLA v20.8h, v4.8h, %[va].8h\n" - "STP q19, q20, [%[outptr3]], #32\n" - "FMLA v21.8h, v5.8h, %[va].8h\n" - "STR q21, [%[outptr3]], #16\n" ASM_PREFETCH("[%[outptr1], #80]") - - // Rows 4-5 - "LDP q16, q17, [%[outptr4]]\n" - "FMUL v16.8h, 
v16.8h, %[vb].8h\n" - "LDR q18, [%[outptr4], #32]\n" - "FMUL v17.8h, v17.8h, %[vb].8h\n" - "LDP q19, q20, [%[outptr5]]\n" - "FMUL v18.8h, v18.8h, %[vb].8h\n" ASM_PREFETCH("[%[outptr2], #80]") - "LDR q21, [%[outptr5], #32]\n" - "FMUL v19.8h, v19.8h, %[vb].8h\n" - "LDP q0, q1, [%[inptr], #192]\n" - "FMUL v20.8h, v20.8h, %[vb].8h\n" - "LDP q2, q3, [%[inptr], #224]\n" - "FMUL v21.8h, v21.8h, %[vb].8h\n" - "LDP q4, q5, [%[inptr], #256]\n" - "FMLA v16.8h, v0.8h, %[va].8h\n" ASM_PREFETCH("[%[outptr3], #80]") - "FMLA v17.8h, v1.8h, %[va].8h\n" - "STP q16, q17, [%[outptr4]], #32\n" - "FMLA v18.8h, v2.8h, %[va].8h\n" - "STR q18, [%[outptr4]], #16\n" - "FMLA v19.8h, v3.8h, %[va].8h\n" ASM_PREFETCH("[%[outptr4], #80]") - "FMLA v20.8h, v4.8h, %[va].8h\n" - "STP q19, q20, [%[outptr5]], #32\n" - "FMLA v21.8h, v5.8h, %[va].8h\n" - "STR q21, [%[outptr5]], #16\n" - - // Rows 6-7 - "LDP q16, q17, [%[outptr6]]\n" - "FMUL v16.8h, v16.8h, %[vb].8h\n" - "LDR q18, [%[outptr6], #32]\n" - "FMUL v17.8h, v17.8h, %[vb].8h\n" - "LDP q19, q20, [%[outptr7]]\n" ASM_PREFETCH("[%[outptr5], #80]") - "FMUL v18.8h, v18.8h, %[vb].8h\n" - "LDR q21, [%[outptr7], #32]\n" - "FMUL v19.8h, v19.8h, %[vb].8h\n" - "LDP q0, q1, [%[inptr], #288]\n" - "FMUL v20.8h, v20.8h, %[vb].8h\n" - "LDP q2, q3, [%[inptr], #320]\n" - "FMUL v21.8h, v21.8h, %[vb].8h\n" - "LDP q4, q5, [%[inptr], #352]\n" - "FMLA v16.8h, v0.8h, %[va].8h\n" ASM_PREFETCH("[%[outptr6], #128]") - "FMLA v17.8h, v1.8h, %[va].8h\n" - "STP q16, q17, [%[outptr6]], #32\n" - "FMLA v18.8h, v2.8h, %[va].8h\n" - "STR q18, [%[outptr6]], #16\n" - "FMLA v19.8h, v3.8h, %[va].8h\n" ASM_PREFETCH("[%[outptr7], #128]") - "FMLA v20.8h, v4.8h, %[va].8h\n" - "STP q19, q20, [%[outptr7]], #32\n" - "FMLA v21.8h, v5.8h, %[va].8h\n" - "STR q21, [%[outptr7]], #16\n" - "ADD %[inptr], %[inptr], #384\n" - : [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3), - [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), [outptr6] "+r"(outptr6), [outptr7] "+r"(outptr7), - [inptr] "+r"(inptr) - : [va] "w"(va), [vb] "w"(vb) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21"); } } } } -#endif // __aarch64__ && __ARM_FEATURE_FP16_SCALAR_ARITHMETIC +#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp index 79dd1f07e3..1a51505a25 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp @@ -25,18 +25,16 @@ #ifdef __aarch64__ -template <> -inline void MergeResults<12, 8>(int32_t *out, const int32_t *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const int32_t alpha, const int32_t beta) -{ +template<> +inline void MergeResults<12, 8>(int32_t *out, const int32_t *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const int32_t alpha, const int32_t beta) { const int32_t *inptr = in; prefetch_6x(inptr); prefetch_6x(inptr + 96); int32x4_t alpha_value = vdupq_n_s32(alpha); - int32x4_t beta_value = vdupq_n_s32(beta); + int32x4_t beta_value = vdupq_n_s32(beta); - for(int y = y0; y < ymax; y += 8) - { + for (int y=y0; y(int32_t *out, const int32_t *in, const int ldout prefetch_2x(outptr6); prefetch_2x(outptr7); - for(int i = x0; i < xmax; i += 12) - { + for (int i=x0; i= ymax) - { - switch((y + 7) - ymax) - { + if ((y+7) >= ymax) { + 
switch ((y + 7) - ymax) { case 6: outptr1 = dummyres; case 5: @@ -88,12 +83,9 @@ inline void MergeResults<12, 8>(int32_t *out, const int32_t *in, const int ldout } /* For ragged X, manually copy over the valid results. */ - if((i + 11) >= xmax) - { - for(int xi = 0; xi < 12; xi++) - { - if((i + xi) < xmax) - { + if ((i+11) >= xmax) { + for (int xi=0; xi<12; xi++) { + if ((i+xi) < xmax) { *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta); outptr0++; *outptr1 = (alpha * inptr[xi + 12]) + (*outptr1 * beta); @@ -113,177 +105,175 @@ inline void MergeResults<12, 8>(int32_t *out, const int32_t *in, const int ldout } } inptr += 96; - } - else - { + } else { /* Optimized routine to copy an entire block */ - __asm __volatile( - // Row 0 - ASM_PREFETCH("[%x[outptr1], #192]") - "ldr q3, [%x[outptr0]]\n" - "ldr q4, [%x[outptr0], #0x10]\n" - "ldr q5, [%x[outptr0], #0x20]\n" - "mul v3.4s, v3.4s, %[beta_value].4s\n" - "ldr q6, [%x[inptr]]\n" - "mul v4.4s, v4.4s, %[beta_value].4s\n" - "ldr q7, [%x[inptr], #0x10]\n" - "mul v5.4s, v5.4s, %[beta_value].4s\n" - "ldr q8, [%x[inptr], #0x20]\n" - "mla v3.4s, v6.4s, %[alpha_value].4s\n" - "ldr q0, [%x[outptr1]]\n" - "mla v4.4s, v7.4s, %[alpha_value].4s\n" - "ldr q1, [%x[outptr1], #0x10]\n" - "mla v5.4s, v8.4s, %[alpha_value].4s\n" - "ldr q2, [%x[outptr1], #0x20]\n" + __asm __volatile ( + // Row 0 + ASM_PREFETCH("[%x[outptr1], #192]") + "ldr q3, [%x[outptr0]]\n" + "ldr q4, [%x[outptr0], #0x10]\n" + "ldr q5, [%x[outptr0], #0x20]\n" + "mul v3.4s, v3.4s, %[beta_value].4s\n" + "ldr q6, [%x[inptr]]\n" + "mul v4.4s, v4.4s, %[beta_value].4s\n" + "ldr q7, [%x[inptr], #0x10]\n" + "mul v5.4s, v5.4s, %[beta_value].4s\n" + "ldr q8, [%x[inptr], #0x20]\n" + "mla v3.4s, v6.4s, %[alpha_value].4s\n" + "ldr q0, [%x[outptr1]]\n" + "mla v4.4s, v7.4s, %[alpha_value].4s\n" + "ldr q1, [%x[outptr1], #0x10]\n" + "mla v5.4s, v8.4s, %[alpha_value].4s\n" + "ldr q2, [%x[outptr1], #0x20]\n" - // Row 1 - ASM_PREFETCH("[%x[outptr2], #192]") - "mul v0.4s, v0.4s, %[beta_value].4s\n" - "ldr q6, [%x[inptr], #0x30]\n" - "str q3, [%x[outptr0]], #0x10\n" - "mul v1.4s, v1.4s, %[beta_value].4s\n" - "ldr q7, [%x[inptr], #0x40]\n" - "str q4, [%x[outptr0]], #0x10\n" - "mul v2.4s, v2.4s, %[beta_value].4s\n" - "ldr q8, [%x[inptr], #0x50]\n" - "str q5, [%x[outptr0]], #0x10\n" - "mla v0.4s, v6.4s, %[alpha_value].4s\n" - "ldr q3, [%x[outptr2]]\n" - "mla v1.4s, v7.4s, %[alpha_value].4s\n" - "ldr q4, [%x[outptr2], #0x10]\n" - "mla v2.4s, v8.4s, %[alpha_value].4s\n" - "ldr q5, [%x[outptr2], #0x20]\n" + // Row 1 + ASM_PREFETCH("[%x[outptr2], #192]") + "mul v0.4s, v0.4s, %[beta_value].4s\n" + "ldr q6, [%x[inptr], #0x30]\n" + "str q3, [%x[outptr0]], #0x10\n" + "mul v1.4s, v1.4s, %[beta_value].4s\n" + "ldr q7, [%x[inptr], #0x40]\n" + "str q4, [%x[outptr0]], #0x10\n" + "mul v2.4s, v2.4s, %[beta_value].4s\n" + "ldr q8, [%x[inptr], #0x50]\n" + "str q5, [%x[outptr0]], #0x10\n" + "mla v0.4s, v6.4s, %[alpha_value].4s\n" + "ldr q3, [%x[outptr2]]\n" + "mla v1.4s, v7.4s, %[alpha_value].4s\n" + "ldr q4, [%x[outptr2], #0x10]\n" + "mla v2.4s, v8.4s, %[alpha_value].4s\n" + "ldr q5, [%x[outptr2], #0x20]\n" - // Row 2 - ASM_PREFETCH("[%x[outptr3], #192]") - "mul v3.4s, v3.4s, %[beta_value].4s\n" - "ldr q6, [%x[inptr], #0x60]\n" - "str q0, [%x[outptr1]], #0x10\n" - "mul v4.4s, v4.4s, %[beta_value].4s\n" - "ldr q7, [%x[inptr], #0x70]\n" - "str q1, [%x[outptr1]], #0x10\n" - "mul v5.4s, v5.4s, %[beta_value].4s\n" - "ldr q8, [%x[inptr], #0x80]\n" - "str q2, [%x[outptr1]], #0x10\n" - "mla v3.4s, v6.4s, %[alpha_value].4s\n" - "ldr q0, 
[%x[outptr3]]\n" - "mla v4.4s, v7.4s, %[alpha_value].4s\n" - "ldr q1, [%x[outptr3], #0x10]\n" - "mla v5.4s, v8.4s, %[alpha_value].4s\n" - "ldr q2, [%x[outptr3], #0x20]\n" + // Row 2 + ASM_PREFETCH("[%x[outptr3], #192]") + "mul v3.4s, v3.4s, %[beta_value].4s\n" + "ldr q6, [%x[inptr], #0x60]\n" + "str q0, [%x[outptr1]], #0x10\n" + "mul v4.4s, v4.4s, %[beta_value].4s\n" + "ldr q7, [%x[inptr], #0x70]\n" + "str q1, [%x[outptr1]], #0x10\n" + "mul v5.4s, v5.4s, %[beta_value].4s\n" + "ldr q8, [%x[inptr], #0x80]\n" + "str q2, [%x[outptr1]], #0x10\n" + "mla v3.4s, v6.4s, %[alpha_value].4s\n" + "ldr q0, [%x[outptr3]]\n" + "mla v4.4s, v7.4s, %[alpha_value].4s\n" + "ldr q1, [%x[outptr3], #0x10]\n" + "mla v5.4s, v8.4s, %[alpha_value].4s\n" + "ldr q2, [%x[outptr3], #0x20]\n" - // Row 3 - ASM_PREFETCH("[%x[outptr4], #192]") - "mul v0.4s, v0.4s, %[beta_value].4s\n" - "ldr q6, [%x[inptr], #0x90]\n" - "str q3, [%x[outptr2]], #0x10\n" - "mul v1.4s, v1.4s, %[beta_value].4s\n" - "ldr q7, [%x[inptr], #0xa0]\n" - "str q4, [%x[outptr2]], #0x10\n" - "mul v2.4s, v2.4s, %[beta_value].4s\n" - "ldr q8, [%x[inptr], #0xb0]\n" - "str q5, [%x[outptr2]], #0x10\n" - "mla v0.4s, v6.4s, %[alpha_value].4s\n" - "ldr q3, [%x[outptr4]]\n" - "mla v1.4s, v7.4s, %[alpha_value].4s\n" - "ldr q4, [%x[outptr4], #0x10]\n" - "mla v2.4s, v8.4s, %[alpha_value].4s\n" - "ldr q5, [%x[outptr4], #0x20]\n" + // Row 3 + ASM_PREFETCH("[%x[outptr4], #192]") + "mul v0.4s, v0.4s, %[beta_value].4s\n" + "ldr q6, [%x[inptr], #0x90]\n" + "str q3, [%x[outptr2]], #0x10\n" + "mul v1.4s, v1.4s, %[beta_value].4s\n" + "ldr q7, [%x[inptr], #0xa0]\n" + "str q4, [%x[outptr2]], #0x10\n" + "mul v2.4s, v2.4s, %[beta_value].4s\n" + "ldr q8, [%x[inptr], #0xb0]\n" + "str q5, [%x[outptr2]], #0x10\n" + "mla v0.4s, v6.4s, %[alpha_value].4s\n" + "ldr q3, [%x[outptr4]]\n" + "mla v1.4s, v7.4s, %[alpha_value].4s\n" + "ldr q4, [%x[outptr4], #0x10]\n" + "mla v2.4s, v8.4s, %[alpha_value].4s\n" + "ldr q5, [%x[outptr4], #0x20]\n" - // Row 4 - ASM_PREFETCH("[%x[outptr5], #192]") - "mul v3.4s, v3.4s, %[beta_value].4s\n" - "ldr q6, [%x[inptr], #0xc0]\n" - "str q0, [%x[outptr3]], #0x10\n" - "mul v4.4s, v4.4s, %[beta_value].4s\n" - "ldr q7, [%x[inptr], #0xd0]\n" - "str q1, [%x[outptr3]], #0x10\n" - "mul v5.4s, v5.4s, %[beta_value].4s\n" - "ldr q8, [%x[inptr], #0xe0]\n" - "str q2, [%x[outptr3]], #0x10\n" - "mla v3.4s, v6.4s, %[alpha_value].4s\n" - "ldr q0, [%x[outptr5]]\n" - "mla v4.4s, v7.4s, %[alpha_value].4s\n" - "ldr q1, [%x[outptr5], #0x10]\n" - "mla v5.4s, v8.4s, %[alpha_value].4s\n" - "ldr q2, [%x[outptr5], #0x20]\n" + // Row 4 + ASM_PREFETCH("[%x[outptr5], #192]") + "mul v3.4s, v3.4s, %[beta_value].4s\n" + "ldr q6, [%x[inptr], #0xc0]\n" + "str q0, [%x[outptr3]], #0x10\n" + "mul v4.4s, v4.4s, %[beta_value].4s\n" + "ldr q7, [%x[inptr], #0xd0]\n" + "str q1, [%x[outptr3]], #0x10\n" + "mul v5.4s, v5.4s, %[beta_value].4s\n" + "ldr q8, [%x[inptr], #0xe0]\n" + "str q2, [%x[outptr3]], #0x10\n" + "mla v3.4s, v6.4s, %[alpha_value].4s\n" + "ldr q0, [%x[outptr5]]\n" + "mla v4.4s, v7.4s, %[alpha_value].4s\n" + "ldr q1, [%x[outptr5], #0x10]\n" + "mla v5.4s, v8.4s, %[alpha_value].4s\n" + "ldr q2, [%x[outptr5], #0x20]\n" - // Row 5 - ASM_PREFETCH("[%x[outptr6], #192]") - "mul v0.4s, v0.4s, %[beta_value].4s\n" - "ldr q6, [%x[inptr], #0xf0]\n" - "str q3, [%x[outptr4]], #0x10\n" - "mul v1.4s, v1.4s, %[beta_value].4s\n" - "ldr q7, [%x[inptr], #0x100]\n" - "str q4, [%x[outptr4]], #0x10\n" - "mul v2.4s, v2.4s, %[beta_value].4s\n" - "ldr q8, [%x[inptr], #0x110]\n" - "str q5, [%x[outptr4]], #0x10\n" - 
"mla v0.4s, v6.4s, %[alpha_value].4s\n" - "ldr q3, [%x[outptr6]]\n" - "mla v1.4s, v7.4s, %[alpha_value].4s\n" - "ldr q4, [%x[outptr6], #0x10]\n" - "mla v2.4s, v8.4s, %[alpha_value].4s\n" - "ldr q5, [%x[outptr6], #0x20]\n" + // Row 5 + ASM_PREFETCH("[%x[outptr6], #192]") + "mul v0.4s, v0.4s, %[beta_value].4s\n" + "ldr q6, [%x[inptr], #0xf0]\n" + "str q3, [%x[outptr4]], #0x10\n" + "mul v1.4s, v1.4s, %[beta_value].4s\n" + "ldr q7, [%x[inptr], #0x100]\n" + "str q4, [%x[outptr4]], #0x10\n" + "mul v2.4s, v2.4s, %[beta_value].4s\n" + "ldr q8, [%x[inptr], #0x110]\n" + "str q5, [%x[outptr4]], #0x10\n" + "mla v0.4s, v6.4s, %[alpha_value].4s\n" + "ldr q3, [%x[outptr6]]\n" + "mla v1.4s, v7.4s, %[alpha_value].4s\n" + "ldr q4, [%x[outptr6], #0x10]\n" + "mla v2.4s, v8.4s, %[alpha_value].4s\n" + "ldr q5, [%x[outptr6], #0x20]\n" - // Row 6 - ASM_PREFETCH("[%x[outptr7], #192]") - "mul v3.4s, v3.4s, %[beta_value].4s\n" - "ldr q6, [%x[inptr], #0x120]\n" - "str q0, [%x[outptr5]], #0x10\n" - "mul v4.4s, v4.4s, %[beta_value].4s\n" - "ldr q7, [%x[inptr], #0x130]\n" - "str q1, [%x[outptr5]], #0x10\n" - "mul v5.4s, v5.4s, %[beta_value].4s\n" - "ldr q8, [%x[inptr], #0x140]\n" - "str q2, [%x[outptr5]], #0x10\n" - "mla v3.4s, v6.4s, %[alpha_value].4s\n" - "ldr q0, [%x[outptr7]]\n" - "mla v4.4s, v7.4s, %[alpha_value].4s\n" - "ldr q1, [%x[outptr7], #0x10]\n" - "mla v5.4s, v8.4s, %[alpha_value].4s\n" - "ldr q2, [%x[outptr7], #0x20]\n" + // Row 6 + ASM_PREFETCH("[%x[outptr7], #192]") + "mul v3.4s, v3.4s, %[beta_value].4s\n" + "ldr q6, [%x[inptr], #0x120]\n" + "str q0, [%x[outptr5]], #0x10\n" + "mul v4.4s, v4.4s, %[beta_value].4s\n" + "ldr q7, [%x[inptr], #0x130]\n" + "str q1, [%x[outptr5]], #0x10\n" + "mul v5.4s, v5.4s, %[beta_value].4s\n" + "ldr q8, [%x[inptr], #0x140]\n" + "str q2, [%x[outptr5]], #0x10\n" + "mla v3.4s, v6.4s, %[alpha_value].4s\n" + "ldr q0, [%x[outptr7]]\n" + "mla v4.4s, v7.4s, %[alpha_value].4s\n" + "ldr q1, [%x[outptr7], #0x10]\n" + "mla v5.4s, v8.4s, %[alpha_value].4s\n" + "ldr q2, [%x[outptr7], #0x20]\n" - // Row 7 - "mul v0.4s, v0.4s, %[beta_value].4s\n" - "ldr q6, [%x[inptr], #0x150]\n" - "str q3, [%x[outptr6]], #0x10\n" - "mul v1.4s, v1.4s, %[beta_value].4s\n" - "ldr q7, [%x[inptr], #0x160]\n" - "str q4, [%x[outptr6]], #0x10\n" - "mul v2.4s, v2.4s, %[beta_value].4s\n" - "ldr q8, [%x[inptr], #0x170]\n" - "str q5, [%x[outptr6]], #0x10\n" - "mla v0.4s, v6.4s, %[alpha_value].4s\n" - "mla v1.4s, v7.4s, %[alpha_value].4s\n" - "mla v2.4s, v8.4s, %[alpha_value].4s\n" - "str q0, [%x[outptr7]], #0x10\n" - "str q1, [%x[outptr7]], #0x10\n" - "str q2, [%x[outptr7]], #0x10\n" + // Row 7 + "mul v0.4s, v0.4s, %[beta_value].4s\n" + "ldr q6, [%x[inptr], #0x150]\n" + "str q3, [%x[outptr6]], #0x10\n" + "mul v1.4s, v1.4s, %[beta_value].4s\n" + "ldr q7, [%x[inptr], #0x160]\n" + "str q4, [%x[outptr6]], #0x10\n" + "mul v2.4s, v2.4s, %[beta_value].4s\n" + "ldr q8, [%x[inptr], #0x170]\n" + "str q5, [%x[outptr6]], #0x10\n" + "mla v0.4s, v6.4s, %[alpha_value].4s\n" + "mla v1.4s, v7.4s, %[alpha_value].4s\n" + "mla v2.4s, v8.4s, %[alpha_value].4s\n" + "str q0, [%x[outptr7]], #0x10\n" + "str q1, [%x[outptr7]], #0x10\n" + "str q2, [%x[outptr7]], #0x10\n" - "add %x[inptr], %x[inptr], #0x180\n" - : [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), - [outptr2] "+r"(outptr2), - [outptr3] "+r"(outptr3), - [outptr4] "+r"(outptr4), - [outptr5] "+r"(outptr5), - [outptr6] "+r"(outptr6), - [outptr7] "+r"(outptr7), - [inptr] "+r"(inptr) - : [alpha_value] "w"(alpha_value), - [beta_value] "w"(beta_value) - : "v0", "v1", "v2", "v3", "v4", 
"v5", "v6", "v7", "v8"); + "add %x[inptr], %x[inptr], #0x180\n" + : [outptr0] "+r" (outptr0), + [outptr1] "+r" (outptr1), + [outptr2] "+r" (outptr2), + [outptr3] "+r" (outptr3), + [outptr4] "+r" (outptr4), + [outptr5] "+r" (outptr5), + [outptr6] "+r" (outptr6), + [outptr7] "+r" (outptr7), + [inptr] "+r" (inptr) + : [alpha_value] "w" (alpha_value), + [beta_value] "w" (beta_value) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8" + ); } } } } -template <> -inline void MergeResults<12, 8>(uint32_t *out, const uint32_t *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const uint32_t alpha, const uint32_t beta) -{ - // Since the above code uses only MUL and MLA instructions discard the "unsignedness" and proceed safely. - MergeResults<12, 8>(reinterpret_cast(out), reinterpret_cast(in), ldout, y0, ymax, x0, xmax, static_cast(alpha), static_cast(beta)); +template<> +inline void MergeResults<12, 8>(uint32_t *out, const uint32_t *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const uint32_t alpha, const uint32_t beta) { + // Since the above code uses only MUL and MLA instructions discard the "unsignedness" and proceed safely. + MergeResults<12, 8>(reinterpret_cast(out), reinterpret_cast(in), ldout, y0, ymax, x0, xmax, static_cast(alpha), static_cast(beta)); } #endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/profiler.hpp b/src/core/NEON/kernels/arm_gemm/profiler.hpp index ada0c95e26..1b944c4ccd 100644 --- a/src/core/NEON/kernels/arm_gemm/profiler.hpp +++ b/src/core/NEON/kernels/arm_gemm/profiler.hpp @@ -31,75 +31,65 @@ #include #endif -namespace arm_gemm -{ +namespace arm_gemm { + #ifndef NO_MULTI_THREADING extern std::mutex report_mutex; #endif -class profiler -{ +class profiler { private: - static const int maxevents = 100000; - unsigned long times[maxevents] = {}; - unsigned long units[maxevents] = {}; - int events[maxevents] = {}; - int currentevent = 0; - int countfd = 0; - - class ScopedProfilerClass - { + static const int maxevents = 100000; + unsigned long times[maxevents] = { }; + unsigned long units[maxevents] = { }; + int events[maxevents] = { }; + int currentevent=0; + int countfd=0; + + class ScopedProfilerClass { private: profiler &_parent; - bool legal = false; + bool legal=false; public: - ScopedProfilerClass(profiler &prof, int i, unsigned long u) - : _parent(prof) - { - if(prof.currentevent == maxevents) + ScopedProfilerClass(profiler &prof, int i, unsigned long u) : _parent(prof) { + if (prof.currentevent==maxevents) return; - prof.events[prof.currentevent] = i; - prof.units[prof.currentevent] = u; - legal = true; + prof.events[prof.currentevent]=i; + prof.units[prof.currentevent]=u; + legal=true; start_counter(prof.countfd); } - ~ScopedProfilerClass() - { - if(!legal) - return; + ~ScopedProfilerClass() { + if (!legal) return; - long long cycs = stop_counter(_parent.countfd); + long long cycs = stop_counter(_parent.countfd); _parent.times[_parent.currentevent++] = cycs; } }; public: - profiler() - { - countfd = open_cycle_counter(); + profiler() { + countfd=open_cycle_counter(); } - ~profiler() - { + ~profiler() { close(countfd); - int tots[5]; + int tots[5]; unsigned long counts[5]; unsigned long tunits[5]; - const char *descs[] = { "Prepare A", "Prepare B", "Kernel", "Merge" }; + const char * descs[] = { "Prepare A", "Prepare B", "Kernel", "Merge" }; - for(int i = 1; i < 5; i++) - { - tots[i] = 0; + for (int i=1; i<5; i++) { + tots[i] = 0; counts[i] = 0; tunits[i] = 0; } - for(int i = 0; i 
< currentevent; i++) - { - // printf("%10s: %ld\n", descs[events[i]-1], times[i]); + for (int i=0; i - void operator()(int i, unsigned long u, T func) - { - if(currentevent == maxevents) - { + void operator() (int i, unsigned long u, T func) { + if (currentevent==maxevents) { func(); - } - else - { + } else { events[currentevent] = i; - units[currentevent] = u; + units[currentevent] = u; start_counter(countfd); func(); - long long cycs = stop_counter(countfd); + long long cycs = stop_counter(countfd); times[currentevent++] = cycs; } } - ScopedProfilerClass ScopedProfiler(int i, unsigned long u) - { + + ScopedProfilerClass ScopedProfiler(int i, unsigned long u) { return ScopedProfilerClass(*this, i, u); } }; diff --git a/src/core/NEON/kernels/arm_gemm/transform.hpp b/src/core/NEON/kernels/arm_gemm/transform.hpp index c80bb59941..35e61b05a4 100644 --- a/src/core/NEON/kernels/arm_gemm/transform.hpp +++ b/src/core/NEON/kernels/arm_gemm/transform.hpp @@ -35,63 +35,51 @@ * being a multiple of the block sizes. */ template -struct TransformImpl -{ +struct TransformImpl { template - static void Transform(TOut *out, const TIn *const in, const int stride, - const int y0, const int ymax, const int x0, const int xmax) - { + static void Transform(TOut* out, const TIn* const in, const int stride, + const int y0, const int ymax, const int x0, const int xmax) { const int n_whole_y_blocks = (ymax - y0) / IntBy; - const int y_remainders = (ymax - y0) % IntBy; - const int n_y_blocks = n_whole_y_blocks + (y_remainders ? 1 : 0); + const int y_remainders = (ymax - y0) % IntBy; + const int n_y_blocks = n_whole_y_blocks + (y_remainders ? 1 : 0); const int n_whole_x_blocks = (xmax - x0) / BlockBy; - const int x_remainders = (xmax - x0) % BlockBy; - const int n_x_blocks = n_whole_x_blocks + (x_remainders ? 1 : 0); + const int x_remainders = (xmax - x0) % BlockBy; + const int n_x_blocks = n_whole_x_blocks + (x_remainders ? 1 : 0); // "Y" loop: advance down the rows of the source IntBy rows at a time. // Set up fill_rows to show the number rows to copy from, and blank_rows // for the number of blank rows to add. - for(int y_block = 0; y_block < n_y_blocks; y_block++) - { - int fill_rows = (y_block < n_whole_y_blocks) ? IntBy : y_remainders; + for (int y_block=0 ; y_block < n_y_blocks; y_block++) { + int fill_rows = (y_block < n_whole_y_blocks) ? IntBy : y_remainders; int blank_rows = IntBy - fill_rows; int y_base = y0 + (y_block * IntBy); // So now advance along this block of rows, BlockBy columns at a time. - for(int x_block = 0; x_block < n_x_blocks; x_block++) - { - int fill_cols = (x_block < n_whole_x_blocks) ? BlockBy : x_remainders; + for (int x_block=0 ; x_block < n_x_blocks; x_block++) { + int fill_cols = (x_block < n_whole_x_blocks) ? BlockBy : x_remainders; int blank_cols = BlockBy - fill_cols; int x_base = x0 + (x_block * BlockBy); - for(int row = 0; row < fill_rows; row++) - { - for(int col = 0; col < fill_cols; col++) - { + for (int row = 0; row < fill_rows; row++) { + for (int col = 0; col < fill_cols; col++) { // In-range copy. If it's transposed, we reverse the sense of rows and columns here. - if(Transposed) - { + if (Transposed) { *out++ = static_cast(in[(x_base + col) * stride + y_base + row]); - } - else - { + } else { *out++ = static_cast(in[(y_base + row) * stride + x_base + col]); } } // "col" tail - row is in range but column is out of range. 
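                    // (The "tail" loops below pad the packed output out to a full
                    //  IntBy x BlockBy tile with zeros.  As an illustrative example -
                    //  values not taken from the patch - with IntBy=6, BlockBy=1 and a
                    //  ragged 4x3 source block, each of the three emitted groups is one
                    //  source column of four values followed by two zeros, so the
                    //  kernels always consume complete tiles.)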
- for(int col = 0; col < blank_cols; col++) - { + for (int col=0; col < blank_cols; col++) { *out++ = static_cast(0); } } // "row" tail - row is out of range so fill with zeros always. - for(int row = 0; row < blank_rows; row++) - { - for(int col = 0; col < (fill_cols + blank_cols); col++) - { + for (int row = 0; row < blank_rows; row++) { + for (int col=0; col < (fill_cols + blank_cols); col++) { *out++ = static_cast(0); } } @@ -100,9 +88,8 @@ struct TransformImpl } template - static inline void Transform(T *out, const T *const in, const int stride, - const int k0, const int kmax, const int x0, const int xmax) - { + static inline void Transform(T* out, const T* const in, const int stride, + const int k0, const int kmax, const int x0, const int xmax) { Transform(out, in, stride, k0, kmax, x0, xmax); } }; @@ -110,13 +97,15 @@ struct TransformImpl /*****************************************************************************/ template void Transform( - TOut *out, const TIn *const in, const int stride, - const int k0, const int kmax, const int x0, const int xmax) -{ - // Redirect to a specialised implementation predicated on argument size. - TransformImpl::Transform( - out, in, stride, k0, kmax, x0, xmax); + TOut* out, const TIn* const in, const int stride, + const int k0, const int kmax, const int x0, const int xmax +) { + // Redirect to a specialised implementation predicated on argument size. + TransformImpl::Transform( + out, in, stride, k0, kmax, x0, xmax + ); } /*****************************************************************************/ #include "transforms/list.hpp" + diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp index 501d6bf075..e485ca7009 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp @@ -29,17 +29,15 @@ #include "../asmlib.hpp" -template <> -template -inline void TransformImpl<6, 1, false, 4, 4>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) -{ - uint32_t *outptr = reinterpret_cast(out); - const uint32_t *inptr = reinterpret_cast(in); +template<> +template +inline void TransformImpl<6, 1, false, 4, 4>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) { + uint32_t *outptr = reinterpret_cast(out); + const uint32_t *inptr = reinterpret_cast(in); uint32_t zerobuff[8]; - for(int y = y0; y < ymax; y += 6) - { + for (int y=y0; y::Transform(T *out, const T *in, int //prefetch_2x(inptr4); //prefetch_2x(inptr5); - int x = (kmax - k0); - for(; x > 7; x -= 8) - { + int x=(kmax-k0); + for (;x>7;x-=8) { /* Cope with ragged cases by copying from a buffer of zeroes instead */ - if((y + 5) >= ymax) - { - switch((y + 5) - ymax) - { + if ((y + 5) >= ymax) { + switch ((y + 5) - ymax) { /* Everything falls through in here */ case 4: inptr1 = zerobuff; @@ -80,67 +75,73 @@ inline void TransformImpl<6, 1, false, 4, 4>::Transform(T *out, const T *in, int } } - __asm __volatile( + + __asm __volatile ( // Load up 8 elements (2 vectors) from each of 8 sources. 
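            // The VLD1s below read eight fp32 values from each of the six source
            // rows (A..F); the VZIP.32 steps then interleave them so that every
            // group of six values stored is one column across those rows, giving
            // the packed order A0 B0 C0 D0 E0 F0, A1 B1 C1 D1 E1 F1, and so on.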
- "VLD1.32 {d0-d3}, [%[inptr0]]!\n" // q0=A0A1A2A3 - "VLD1.32 {d4-d7}, [%[inptr1]]!\n" // q2=B0B1B2B3 - "VLD1.32 {d8-d11}, [%[inptr2]]!\n" // q4=C0C1C2C3 - "VZIP.32 q0, q4\n" // q0=A0C0A1C1, q4 = A2C2A3C3 - "VLD1.32 {d12-d15}, [%[inptr3]]!\n" // q6=D0D1D2D3 - "VZIP.32 q2, q6\n" // q2=B0D0B1D1, q6 = B2D2B3D3 - "VLD1.32 {d16-d19}, [%[inptr4]]!\n" - "VLD1.32 {d20-d23}, [%[inptr5]]!\n" - "VZIP.32 q8, q10\n" // q8=E0F0E1F1, q10 = E2F2E3F3 + "VLD1.32 {d0-d3}, [%[inptr0]]!\n" // q0=A0A1A2A3 + "VLD1.32 {d4-d7}, [%[inptr1]]!\n" // q2=B0B1B2B3 + "VLD1.32 {d8-d11}, [%[inptr2]]!\n" // q4=C0C1C2C3 + "VZIP.32 q0, q4\n" // q0=A0C0A1C1, q4 = A2C2A3C3 + "VLD1.32 {d12-d15}, [%[inptr3]]!\n" // q6=D0D1D2D3 + "VZIP.32 q2, q6\n" // q2=B0D0B1D1, q6 = B2D2B3D3 + "VLD1.32 {d16-d19}, [%[inptr4]]!\n" + "VLD1.32 {d20-d23}, [%[inptr5]]!\n" + "VZIP.32 q8, q10\n" // q8=E0F0E1F1, q10 = E2F2E3F3 ASM_PREFETCH("[%[inptr0], #128]") - "VZIP.32 q0, q2\n" // q0 = A0B0C0D0, q2 = A1B1C1D1 + "VZIP.32 q0, q2\n" // q0 = A0B0C0D0, q2 = A1B1C1D1 // Store first elements - "VST1.32 {d0-d1}, [%[outptr]]!\n" - "VST1.32 {d16}, [%[outptr]]!\n" + "VST1.32 {d0-d1}, [%[outptr]]!\n" + "VST1.32 {d16}, [%[outptr]]!\n" - "VZIP.32 q4, q6\n" // q4 = A2B2C2D2, q6 = A3B3C3D3 + "VZIP.32 q4, q6\n" // q4 = A2B2C2D2, q6 = A3B3C3D3 // Store second elements - "VST1.32 {d4-d5}, [%[outptr]]!\n" - "VZIP.32 q1, q5\n" ASM_PREFETCH("[%[inptr1], #128]") - "VST1.32 {d17}, [%[outptr]]!\n" - "VZIP.32 q3, q7\n" + "VST1.32 {d4-d5}, [%[outptr]]!\n" + "VZIP.32 q1, q5\n" + ASM_PREFETCH("[%[inptr1], #128]") + "VST1.32 {d17}, [%[outptr]]!\n" + "VZIP.32 q3, q7\n" // Store third elements - "VZIP.32 q9, q11\n" - "VST1.32 {d8-d9}, [%[outptr]]!\n" - "VZIP.32 q1, q3\n" ASM_PREFETCH("[%[inptr2], #128]") - "VST1.32 {d20}, [%[outptr]]!\n" + "VZIP.32 q9, q11\n" + "VST1.32 {d8-d9}, [%[outptr]]!\n" + "VZIP.32 q1, q3\n" + ASM_PREFETCH("[%[inptr2], #128]") + "VST1.32 {d20}, [%[outptr]]!\n" // Store fourth elements - "VZIP.32 q5, q7\n" - "VST1.32 {d12-d13}, [%[outptr]]!\n" ASM_PREFETCH("[%[inptr3], #128]") - "VST1.32 {d21}, [%[outptr]]!\n" + "VZIP.32 q5, q7\n" + "VST1.32 {d12-d13}, [%[outptr]]!\n" + ASM_PREFETCH("[%[inptr3], #128]") + "VST1.32 {d21}, [%[outptr]]!\n" // Fifth - "VST1.32 {d2-d3}, [%[outptr]]!\n" ASM_PREFETCH("[%[inptr4], #128]") - "VST1.32 {d18}, [%[outptr]]!\n" + "VST1.32 {d2-d3}, [%[outptr]]!\n" + ASM_PREFETCH("[%[inptr4], #128]") + "VST1.32 {d18}, [%[outptr]]!\n" // Sixth - "VST1.32 {d6-d7}, [%[outptr]]!\n" ASM_PREFETCH("[%[inptr5], #128]") - "VST1.32 {d19}, [%[outptr]]!\n" + "VST1.32 {d6-d7}, [%[outptr]]!\n" + ASM_PREFETCH("[%[inptr5], #128]") + "VST1.32 {d19}, [%[outptr]]!\n" // Seventh - "VST1.32 {d10-d11}, [%[outptr]]!\n" - "VST1.32 {d22}, [%[outptr]]!\n" + "VST1.32 {d10-d11}, [%[outptr]]!\n" + "VST1.32 {d22}, [%[outptr]]!\n" // Eighth - "VST1.32 {d14-d15}, [%[outptr]]!\n" - "VST1.32 {d23}, [%[outptr]]!\n" + "VST1.32 {d14-d15}, [%[outptr]]!\n" + "VST1.32 {d23}, [%[outptr]]!\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), [outptr] "+r"(outptr) + : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), + [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [outptr] "+r" (outptr) : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12" + ); } - for(; x > 0; x--) - { + for (;x>0;x--) { *outptr++ = *inptr0++; *outptr++ = 
*inptr1++; *outptr++ = *inptr2++; @@ -151,4 +152,4 @@ inline void TransformImpl<6, 1, false, 4, 4>::Transform(T *out, const T *in, int } } -#endif // __arm__ +#endif // __arm__ diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp index ea32c9665c..a7e17fa074 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp @@ -31,86 +31,97 @@ template <> template inline void TransformImpl<8, 1, true, 4, 4>::Transform( - T *out, const T *const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax) -{ - // Redirect to a 16x uint16_t specialisation - TransformImpl<16, 1, true, 2, 2>::Transform( - reinterpret_cast(out), - reinterpret_cast(in), - stride * 2, x0 * 2, xmax * 2, k0, kmax); + T* out, const T* const in, const int stride, + const int x0, const int xmax, const int k0, const int kmax +) { + // Redirect to a 16x uint16_t specialisation + TransformImpl<16, 1, true, 2, 2>::Transform( + reinterpret_cast(out), + reinterpret_cast(in), + stride*2, x0*2, xmax*2, k0, kmax + ); } // Generic 12x16-bit sized specialisation template <> template inline void TransformImpl<16, 1, true, 2, 2>::Transform( - T *out, const T *const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax) -{ - // Redirect to a uint16_t specialisation - Transform( - reinterpret_cast(out), - reinterpret_cast(in), - stride, x0, xmax, k0, kmax); + T* out, const T* const in, const int stride, + const int x0, const int xmax, const int k0, const int kmax +) { + // Redirect to a uint16_t specialisation + Transform( + reinterpret_cast(out), + reinterpret_cast(in), + stride, x0, xmax, k0, kmax + ); } // Specialised 16 x uint16_t version template <> -inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out) -{ - __asm volatile( - "VLD1.32 {d0-d3}, [%[in0]]!\n" - "VST1.32 {d0-d3}, [%[out]]\n" ASM_PREFETCH("[%[in0], #192]") - : [in0] "+r"(in0), - [out] "+r"(out) - : - : "q0", "q1", "memory"); +inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out) { + __asm volatile ( + "VLD1.32 {d0-d3}, [%[in0]]!\n" + "VST1.32 {d0-d3}, [%[out]]\n" + ASM_PREFETCH("[%[in0], #192]") + : [in0] "+r" (in0), + [out] "+r" (out) + : + : "q0", "q1", "memory" + ); } template <> -inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *out) -{ - __asm volatile( - "VLD1.32 {d0-d3}, [%[in0]]!\n" - "VST1.32 {d0-d3}, [%[out]]!\n" ASM_PREFETCH("[%[in0], #192]") - "VLD1.32 {d0-d3}, [%[in1]]!\n" - "VST1.32 {d0-d3}, [%[out]]\n" ASM_PREFETCH("[%[in1], #192]") "SUB %[out], %[out], #32\n" - : [in0] "+r"(in0), - [in1] "+r"(in1), - [out] "+r"(out) - : - : "q0", "q1", "memory"); +inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *out) { + __asm volatile ( + "VLD1.32 {d0-d3}, [%[in0]]!\n" + "VST1.32 {d0-d3}, [%[out]]!\n" + ASM_PREFETCH("[%[in0], #192]") + "VLD1.32 {d0-d3}, [%[in1]]!\n" + "VST1.32 {d0-d3}, [%[out]]\n" + ASM_PREFETCH("[%[in1], #192]") + "SUB %[out], %[out], #32\n" + : [in0] "+r" (in0), + [in1] "+r" (in1), + [out] "+r" (out) + : + : "q0", "q1", "memory" + ); } template <> -inline void 
TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out) -{ - __asm __volatile( - "VLD1.32 {d0-d3}, [%[in0]]!\n" - "VST1.32 {d0-d3}, [%[out]]!\n" ASM_PREFETCH("[%[in0], #192]") - "VLD1.32 {d0-d3}, [%[in1]]!\n" - "VST1.32 {d0-d3}, [%[out]]!\n" ASM_PREFETCH("[%[in1], #192]") - "VLD1.32 {d0-d3}, [%[in2]]!\n" - "VST1.32 {d0-d3}, [%[out]]!\n" ASM_PREFETCH("[%[in2], #192]") - "VLD1.32 {d0-d3}, [%[in3]]!\n" - "VST1.32 {d0-d3}, [%[out]]\n" ASM_PREFETCH("[%[in3], #192]") "SUB %[out], %[out], #96\n" - : [in0] "+r"(in0), - [in1] "+r"(in1), - [in2] "+r"(in2), - [in3] "+r"(in3), - [out] "+r"(out) - : - : "q0", "q1", "memory"); +inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out) { + __asm __volatile ( + "VLD1.32 {d0-d3}, [%[in0]]!\n" + "VST1.32 {d0-d3}, [%[out]]!\n" + ASM_PREFETCH("[%[in0], #192]") + "VLD1.32 {d0-d3}, [%[in1]]!\n" + "VST1.32 {d0-d3}, [%[out]]!\n" + ASM_PREFETCH("[%[in1], #192]") + "VLD1.32 {d0-d3}, [%[in2]]!\n" + "VST1.32 {d0-d3}, [%[out]]!\n" + ASM_PREFETCH("[%[in2], #192]") + "VLD1.32 {d0-d3}, [%[in3]]!\n" + "VST1.32 {d0-d3}, [%[out]]\n" + ASM_PREFETCH("[%[in3], #192]") + "SUB %[out], %[out], #96\n" + : [in0] "+r" (in0), + [in1] "+r" (in1), + [in2] "+r" (in2), + [in3] "+r" (in3), + [out] "+r" (out) + : + : "q0", "q1", "memory" + ); } template <> template <> inline void TransformImpl<16, 1, true, 2, 2>::Transform( - uint16_t *out, const uint16_t *const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax) -{ - TransposeInterleaveCommon<16, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax); + uint16_t* out, const uint16_t* const in, const int stride, + const int x0, const int xmax, const int k0, const int kmax +) { + TransposeInterleaveCommon<16, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax); } #endif // __arm__ diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp index 8d61f15cec..7e61f425d4 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp @@ -30,17 +30,15 @@ #include "../asmlib.hpp" #include "../utils.hpp" -template <> -template -void TransformImpl<4, 16, false, 1, 1>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) -{ - uint8_t *outptr = (uint8_t *)out; - const uint8_t *inptr = (uint8_t *)in; +template<> +template +void TransformImpl<4, 16, false, 1, 1>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) { + uint8_t *outptr = (uint8_t *)out; + const uint8_t *inptr = (uint8_t *)in; uint8_t zerobuff[16]; - for(int y = y0; y < ymax; y += 4) - { + for (int y=y0; y::Transform(T *out, const T *in, int ldin, prefetch_2x(inptr2); prefetch_2x(inptr3); - int x = (kmax - k0); - for(; x > 15; x -= 16) - { + int x=(kmax-k0); + for (;x>15;x-=16) { /* Cope with ragged cases by copying from a buffer of zeroes instead */ - if((y + 3) >= ymax) - { - switch((y + 3) - ymax) - { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { /* Everything falls through in here */ case 2: inptr1 = zerobuff; @@ -73,23 +68,28 @@ void TransformImpl<4, 16, false, 1, 1>::Transform(T *out, const T *in, int ldin, } } - __asm __volatile( 
- "LDR q0, [%[inptr0]], #16\n" ASM_PREFETCH("[%[inptr0], #176]") "LDR q1, [%[inptr1]], #16\n" ASM_PREFETCH("[%[inptr1], #176]") - "STP q0, q1, [%[outptr]], #32\n" - "LDR q0, [%[inptr2]], #16\n" ASM_PREFETCH("[%[inptr2], #176]") "LDR q1, [%[inptr3]], #16\n" ASM_PREFETCH("[%[inptr3], #176]") "STP q0, q1, [%[outptr]], #32\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + __asm __volatile ( + "LDR q0, [%[inptr0]], #16\n" + ASM_PREFETCH("[%[inptr0], #176]") + "LDR q1, [%[inptr1]], #16\n" + ASM_PREFETCH("[%[inptr1], #176]") + "STP q0, q1, [%[outptr]], #32\n" + "LDR q0, [%[inptr2]], #16\n" + ASM_PREFETCH("[%[inptr2], #176]") + "LDR q1, [%[inptr3]], #16\n" + ASM_PREFETCH("[%[inptr3], #176]") + "STP q0, q1, [%[outptr]], #32\n" + : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), + [outptr] "+r" (outptr) : - : "v0", "v1"); + : "v0", "v1" + ); } - if(x > 0) - { + if (x>0) { /* Need to duplicate this here, in case we didn't run the main loop. */ - if((y + 3) >= ymax) - { - switch((y + 3) - ymax) - { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { /* Everything falls through in here */ case 2: inptr1 = zerobuff; @@ -105,16 +105,11 @@ void TransformImpl<4, 16, false, 1, 1>::Transform(T *out, const T *in, int ldin, } /* We have to write out 16 values, copy as many legal values as there are and pad with 0 */ - auto f = [&outptr, x](const uint8_t *&p) - { - for(int i = 0; i < 16; i++) - { - if(i < x) - { + auto f = [&outptr, x](const uint8_t *&p) { + for (int i=0; i<16; i++) { + if (i < x) { *outptr++ = *p++; - } - else - { + } else { *outptr++ = 0; } } @@ -128,4 +123,4 @@ void TransformImpl<4, 16, false, 1, 1>::Transform(T *out, const T *in, int ldin, } } -#endif // __aarch64__ \ No newline at end of file +#endif // __aarch64__ \ No newline at end of file diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp index 3cbc8815e3..99bb2d66bd 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp @@ -29,17 +29,15 @@ #include "../asmlib.hpp" -template <> -template -void TransformImpl<8, 1, false, 2, 2>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) -{ - uint16_t *outptr = (uint16_t *)out; - const uint16_t *inptr = (const uint16_t *)in; +template<> +template +void TransformImpl<8, 1, false, 2, 2>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) { + uint16_t *outptr = (uint16_t *)out; + const uint16_t *inptr = (const uint16_t *)in; uint16_t zerobuff[24]; - for(int y = y0; y < ymax; y += 8) - { + for (int y=y0; y::Transform(T *out, const T *in, int ldin, prefetch_2x(inptr6); prefetch_2x(inptr7); - int x = (kmax - k0); - for(; x > 7; x -= 8) - { + int x=(kmax-k0); + for (;x>7;x-=8) { /* Cope with ragged cases by copying from a buffer of zeroes instead */ - if((y + 7) >= ymax) - { - switch((y + 7) - ymax) - { + if ((y + 7) >= ymax) { + switch ((y + 7) - ymax) { /* Everything falls through in here */ case 6: inptr1 = zerobuff; @@ -89,72 +84,74 @@ void TransformImpl<8, 1, false, 2, 2>::Transform(T *out, const T *in, int ldin, } int skippf = (x & 31); - __asm __volatile( + __asm __volatile ( // Load up 8 elements (1 vector) from each of 8 sources. 
- "CBNZ %w[skippf], 1f\n" ASM_PREFETCH("[%[inptr0], #128]") + "CBNZ %w[skippf], 1f\n" + ASM_PREFETCH("[%[inptr0], #128]") ASM_PREFETCH("[%[inptr1], #128]") ASM_PREFETCH("[%[inptr2], #128]") ASM_PREFETCH("[%[inptr3], #128]") "1:\n" - "LDR q0, [%[inptr0]], #16\n" // q0=A0A1A2A3A4A5A6A7 - "LDR q4, [%[inptr4]], #16\n" // q8=E0E1E2E3E4E5E6E7 - "LDR q2, [%[inptr2]], #16\n" // q4=C0C1C2C3... - "LDR q6, [%[inptr6]], #16\n" - "ZIP1 v8.8h, v0.8h, v4.8h\n" // q8=A0E0A1E1A2E2A3E3 - "ZIP2 v16.8h, v0.8h, v4.8h\n" // q16=A4E4A5E5A6E6A7E7 - "ZIP1 v9.8h, v2.8h, v6.8h\n" // q9=C0G0C1G1C2G2C3G3 - "ZIP2 v17.8h, v2.8h, v6.8h\n" // q17=C4G4C5G5C6G6C7G7 - "LDR q1, [%[inptr1]], #16\n" // q1=B0B1B2B3B4B5B6B7 - "LDR q5, [%[inptr5]], #16\n" - "LDR q3, [%[inptr3]], #16\n" // q3=D0D1D2D3.... - "LDR q7, [%[inptr7]], #16\n" - "ZIP1 v10.8h, v1.8h, v5.8h\n" // q18=B0F0B1F1B2F2B3F3 - "ZIP2 v18.8h, v1.8h, v5.8h\n" // q18=B4F4B5F5B6F6B7F7 - "ZIP1 v11.8h, v3.8h, v7.8h\n" // q19=D0H0D1H1D2H2D3H3 - "ZIP2 v19.8h, v3.8h, v7.8h\n" // q19=D4H4D5H5D6H6D7H7 - - "ZIP1 v12.8h, v8.8h, v9.8h\n" // q20=A0C0E0G0A1C1E1G1 - "ZIP2 v20.8h, v8.8h, v9.8h\n" - "ZIP1 v13.8h, v10.8h, v11.8h\n" // q21=B0D0F0H0B1I1F1H1 - "ZIP2 v21.8h, v10.8h, v11.8h\n" - - "CBNZ %w[skippf], 2f\n" ASM_PREFETCH("[%[inptr4], #112]") + "LDR q0, [%[inptr0]], #16\n" // q0=A0A1A2A3A4A5A6A7 + "LDR q4, [%[inptr4]], #16\n" // q8=E0E1E2E3E4E5E6E7 + "LDR q2, [%[inptr2]], #16\n" // q4=C0C1C2C3... + "LDR q6, [%[inptr6]], #16\n" + "ZIP1 v8.8h, v0.8h, v4.8h\n" // q8=A0E0A1E1A2E2A3E3 + "ZIP2 v16.8h, v0.8h, v4.8h\n" // q16=A4E4A5E5A6E6A7E7 + "ZIP1 v9.8h, v2.8h, v6.8h\n" // q9=C0G0C1G1C2G2C3G3 + "ZIP2 v17.8h, v2.8h, v6.8h\n" // q17=C4G4C5G5C6G6C7G7 + "LDR q1, [%[inptr1]], #16\n" // q1=B0B1B2B3B4B5B6B7 + "LDR q5, [%[inptr5]], #16\n" + "LDR q3, [%[inptr3]], #16\n" // q3=D0D1D2D3.... 
+ "LDR q7, [%[inptr7]], #16\n" + "ZIP1 v10.8h, v1.8h, v5.8h\n" // q18=B0F0B1F1B2F2B3F3 + "ZIP2 v18.8h, v1.8h, v5.8h\n" // q18=B4F4B5F5B6F6B7F7 + "ZIP1 v11.8h, v3.8h, v7.8h\n" // q19=D0H0D1H1D2H2D3H3 + "ZIP2 v19.8h, v3.8h, v7.8h\n" // q19=D4H4D5H5D6H6D7H7 + + "ZIP1 v12.8h, v8.8h, v9.8h\n" // q20=A0C0E0G0A1C1E1G1 + "ZIP2 v20.8h, v8.8h, v9.8h\n" + "ZIP1 v13.8h, v10.8h, v11.8h\n" // q21=B0D0F0H0B1I1F1H1 + "ZIP2 v21.8h, v10.8h, v11.8h\n" + + "CBNZ %w[skippf], 2f\n" + ASM_PREFETCH("[%[inptr4], #112]") ASM_PREFETCH("[%[inptr5], #112]") ASM_PREFETCH("[%[inptr6], #112]") ASM_PREFETCH("[%[inptr7], #112]") "2:\n" - "ZIP1 v22.8h, v16.8h, v17.8h\n" - "ZIP2 v30.8h, v16.8h, v17.8h\n" - "ZIP1 v23.8h, v18.8h, v19.8h\n" - "ZIP2 v31.8h, v18.8h, v19.8h\n" - - "ZIP1 v14.8h, v12.8h, v13.8h\n" // q22=A0B0C0D0E0F0G0H0 - "ZIP2 v15.8h, v12.8h, v13.8h\n" // q23=A1B1C1D1E1F1G1H1 - "STP q14, q15, [%[outptr]], #32\n" // Write back first two elements - - "ZIP1 v0.8h, v20.8h, v21.8h\n" - "ZIP2 v1.8h, v20.8h, v21.8h\n" - "STP q0, q1, [%[outptr]], #32\n" // Write back next two elements - - "ZIP1 v2.8h, v22.8h, v23.8h\n" - "ZIP2 v3.8h, v22.8h, v23.8h\n" - "STP q2, q3, [%[outptr]], #32\n" // Write back next two elements - - "ZIP1 v4.8h, v30.8h, v31.8h\n" - "ZIP2 v5.8h, v30.8h, v31.8h\n" - "STP q4, q5, [%[outptr]], #32\n" // Write back last two elements - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) - : [skippf] "r"(skippf) + "ZIP1 v22.8h, v16.8h, v17.8h\n" + "ZIP2 v30.8h, v16.8h, v17.8h\n" + "ZIP1 v23.8h, v18.8h, v19.8h\n" + "ZIP2 v31.8h, v18.8h, v19.8h\n" + + "ZIP1 v14.8h, v12.8h, v13.8h\n" // q22=A0B0C0D0E0F0G0H0 + "ZIP2 v15.8h, v12.8h, v13.8h\n" // q23=A1B1C1D1E1F1G1H1 + "STP q14, q15, [%[outptr]], #32\n" // Write back first two elements + + "ZIP1 v0.8h, v20.8h, v21.8h\n" + "ZIP2 v1.8h, v20.8h, v21.8h\n" + "STP q0, q1, [%[outptr]], #32\n" // Write back next two elements + + "ZIP1 v2.8h, v22.8h, v23.8h\n" + "ZIP2 v3.8h, v22.8h, v23.8h\n" + "STP q2, q3, [%[outptr]], #32\n" // Write back next two elements + + "ZIP1 v4.8h, v30.8h, v31.8h\n" + "ZIP2 v5.8h, v30.8h, v31.8h\n" + "STP q4, q5, [%[outptr]], #32\n" // Write back last two elements + : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), + [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7), [outptr] "+r" (outptr) + : [skippf] "r" (skippf) : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", - "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); } - for(; x > 0; x--) - { + for (;x>0;x--) { *outptr++ = *inptr0++; *outptr++ = *inptr1++; *outptr++ = *inptr2++; diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp index 47e4fa2608..83391cc59f 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp @@ -29,17 +29,15 @@ #include "../asmlib.hpp" -template <> -template -inline void TransformImpl<8, 1, false, 4, 4>::Transform(T *out, const T *in, int ldin, int y0, int ymax, 
int k0, int kmax) -{ - uint32_t *outptr = (uint32_t *)out; - const uint32_t *inptr = (uint32_t *)in; +template<> +template +inline void TransformImpl<8, 1, false, 4, 4>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) { + uint32_t *outptr = (uint32_t *)out; + const uint32_t *inptr = (uint32_t *)in; uint32_t zerobuff[8]; - for(int y = y0; y < ymax; y += 8) - { + for (int y=y0; y::Transform(T *out, const T *in, int prefetch_2x(inptr6); prefetch_2x(inptr7); - int x = (kmax - k0); - for(; x > 7; x -= 8) - { + int x=(kmax-k0); + for (;x>7;x-=8) { /* Cope with ragged cases by copying from a buffer of zeroes instead */ - if((y + 7) >= ymax) - { - switch((y + 7) - ymax) - { + if ((y + 7) >= ymax) { + switch ((y + 7) - ymax) { /* Everything falls through in here */ case 6: inptr1 = zerobuff; @@ -88,19 +83,20 @@ inline void TransformImpl<8, 1, false, 4, 4>::Transform(T *out, const T *in, int } } - __asm __volatile( + __asm __volatile ( // Load up 8 elements (2 vectors) from each of 8 sources. "LDP q0, q1, [%[inptr0]], #32\n" // q0=A0A1A2A3 "LDP q2, q3, [%[inptr1]], #32\n" // q2=B0B1B2B3 "LDP q4, q5, [%[inptr2]], #32\n" // q4=C0C1C2C3 - "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1 + "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1 ASM_PREFETCH("[%[inptr0], #128]") "LDP q6, q7, [%[inptr3]], #32\n" // q6=D0D1D2D3 - "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1 + "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1 "LDP q8, q9, [%[inptr4]], #32\n" "LDP q10, q11, [%[inptr5]], #32\n" "LDP q12, q13, [%[inptr6]], #32\n" - "ZIP1 v18.4s, v8.4s, v12.4s\n" ASM_PREFETCH("[%[inptr1], #128]") + "ZIP1 v18.4s, v8.4s, v12.4s\n" + ASM_PREFETCH("[%[inptr1], #128]") "LDP q14, q15, [%[inptr7]], #32\n" "ZIP1 v19.4s, v10.4s, v14.4s\n" @@ -110,7 +106,8 @@ inline void TransformImpl<8, 1, false, 4, 4>::Transform(T *out, const T *in, int "ZIP2 v22.4s, v16.4s, v17.4s\n" "ZIP2 v23.4s, v18.4s, v19.4s\n" - "ZIP2 v16.4s, v0.4s, v4.4s\n" ASM_PREFETCH("[%[inptr3], #128]") + "ZIP2 v16.4s, v0.4s, v4.4s\n" + ASM_PREFETCH("[%[inptr3], #128]") "ZIP2 v17.4s, v2.4s, v6.4s\n" "STP q20, q21, [%[outptr]], #32\n" // Write back the first element of each source @@ -118,12 +115,14 @@ inline void TransformImpl<8, 1, false, 4, 4>::Transform(T *out, const T *in, int "ZIP2 v19.4s, v10.4s, v14.4s\n" "STP q22, q23, [%[outptr]], #32\n" // Write back the second element of each source - "ZIP1 v20.4s, v16.4s, v17.4s\n" ASM_PREFETCH("[%[inptr4], #128]") + "ZIP1 v20.4s, v16.4s, v17.4s\n" + ASM_PREFETCH("[%[inptr4], #128]") "ZIP1 v21.4s, v18.4s, v19.4s\n" "ZIP2 v22.4s, v16.4s, v17.4s\n" "ZIP2 v23.4s, v18.4s, v19.4s\n" - "ZIP1 v16.4s, v1.4s, v5.4s\n" ASM_PREFETCH("[%[inptr5], #128]") + "ZIP1 v16.4s, v1.4s, v5.4s\n" + ASM_PREFETCH("[%[inptr5], #128]") "ZIP1 v17.4s, v3.4s, v7.4s\n" "STP q20, q21, [%[outptr]], #32\n" // Third element @@ -133,14 +132,16 @@ inline void TransformImpl<8, 1, false, 4, 4>::Transform(T *out, const T *in, int "ZIP1 v20.4s, v16.4s, v17.4s\n" "ZIP1 v21.4s, v18.4s, v19.4s\n" - "ZIP2 v22.4s, v16.4s, v17.4s\n" ASM_PREFETCH("[%[inptr6], #128]") + "ZIP2 v22.4s, v16.4s, v17.4s\n" + ASM_PREFETCH("[%[inptr6], #128]") "ZIP2 v23.4s, v18.4s, v19.4s\n" "ZIP2 v16.4s, v1.4s, v5.4s\n" "ZIP2 v17.4s, v3.4s, v7.4s\n" "STP q20, q21, [%[outptr]], #32\n" // Fifth element - "ZIP2 v18.4s, v9.4s, v13.4s\n" ASM_PREFETCH("[%[inptr7], #128]") + "ZIP2 v18.4s, v9.4s, v13.4s\n" + ASM_PREFETCH("[%[inptr7], #128]") "ZIP2 v19.4s, v11.4s, v15.4s\n" "STP q22, q23, [%[outptr]], #32\n" // Sixth element @@ -151,15 +152,15 @@ inline void TransformImpl<8, 1, 
false, 4, 4>::Transform(T *out, const T *in, int "ZIP2 v22.4s, v16.4s, v17.4s\n" "ZIP2 v23.4s, v18.4s, v19.4s\n" "STP q22, q23, [%[outptr]], #32\n" // Eighth element - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), + [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7), [outptr] "+r" (outptr) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + ); } - for(; x > 0; x--) - { + for (;x>0;x--) { *outptr++ = *inptr0++; *outptr++ = *inptr1++; *outptr++ = *inptr2++; @@ -172,4 +173,4 @@ inline void TransformImpl<8, 1, false, 4, 4>::Transform(T *out, const T *in, int } } -#endif // __aarch64__ +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp index 1d2d4969f6..fd812165fd 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp @@ -29,17 +29,15 @@ #include "../asmlib.hpp" -template <> -template <> -inline void TransformImpl<8, 1, false, 4, 2>::Transform(float *out, const __fp16 *in, int ldin, int y0, int ymax, int k0, int kmax) -{ - float *outptr = out; - const __fp16 *inptr = in; +template<> +template<> +inline void TransformImpl<8, 1, false, 4, 2>::Transform(float *out, const __fp16 *in, int ldin, int y0, int ymax, int k0, int kmax) { + float *outptr = out; + const __fp16 *inptr = in; __fp16 zerobuff[8]; - for(int y = y0; y < ymax; y += 8) - { + for (int y=y0; y::Transform(float *out, const __fp16 prefetch_2x(inptr6); prefetch_2x(inptr7); - int x = (kmax - k0); - for(; x > 7; x -= 8) - { + int x=(kmax-k0); + for (;x>7;x-=8) { /* Cope with ragged cases by copying from a buffer of zeroes instead */ - if((y + 7) >= ymax) - { - switch((y + 7) - ymax) - { + if ((y + 7) >= ymax) { + switch ((y + 7) - ymax) { /* Everything falls through in here */ case 6: inptr1 = zerobuff; @@ -88,95 +83,100 @@ inline void TransformImpl<8, 1, false, 4, 2>::Transform(float *out, const __fp16 } } - __asm __volatile( + __asm __volatile ( // Load up 8 elements (2 vectors) from each of 8 sources. 
- "LDR q0, [%[inptr0]], #16\n" - "LDR q2, [%[inptr1]], #16\n" - "FCVTL2 v1.4s, v0.8h\n" - "FCVTL v0.4s, v0.4h\n" - "LDR q4, [%[inptr2]], #16\n" // q4=C0C1C2C3 - "FCVTL2 v3.4s, v2.8h\n" - "FCVTL v2.4s, v2.4h\n" - "FCVTL2 v5.4s, v4.8h\n" - "FCVTL v4.4s, v4.4h\n" - "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1 + "LDR q0, [%[inptr0]], #16\n" + "LDR q2, [%[inptr1]], #16\n" + "FCVTL2 v1.4s, v0.8h\n" + "FCVTL v0.4s, v0.4h\n" + "LDR q4, [%[inptr2]], #16\n" // q4=C0C1C2C3 + "FCVTL2 v3.4s, v2.8h\n" + "FCVTL v2.4s, v2.4h\n" + "FCVTL2 v5.4s, v4.8h\n" + "FCVTL v4.4s, v4.4h\n" + "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1 ASM_PREFETCH("[%[inptr0], #128]") - "LDR q6, [%[inptr3]], #16\n" // q6=D0D1D2D3 - "FCVTL2 v7.4s, v6.8h\n" - "FCVTL v6.4s, v6.4h\n" - "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1 - "LDR q8, [%[inptr4]], #16\n" - "LDR q10, [%[inptr5]], #16\n" - "FCVTL2 v9.4s, v8.8h\n" - "FCVTL v8.4s, v8.4h\n" ASM_PREFETCH("[%[inptr1], #128]") - "LDR q12, [%[inptr6]], #16\n" - "FCVTL2 v11.4s, v10.8h\n" - "FCVTL v10.4s, v10.4h\n" - "FCVTL2 v13.4s, v12.8h\n" - "FCVTL v12.4s, v12.4h\n" - "ZIP1 v18.4s, v8.4s, v12.4s\n" - "LDR q14, [%[inptr7]], #16\n" - "FCVTL2 v15.4s, v14.8h\n" - "FCVTL v14.4s, v14.4h\n" - "ZIP1 v19.4s, v10.4s, v14.4s\n" + "LDR q6, [%[inptr3]], #16\n" // q6=D0D1D2D3 + "FCVTL2 v7.4s, v6.8h\n" + "FCVTL v6.4s, v6.4h\n" + "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1 + "LDR q8, [%[inptr4]], #16\n" + "LDR q10, [%[inptr5]], #16\n" + "FCVTL2 v9.4s, v8.8h\n" + "FCVTL v8.4s, v8.4h\n" + ASM_PREFETCH("[%[inptr1], #128]") + "LDR q12, [%[inptr6]], #16\n" + "FCVTL2 v11.4s, v10.8h\n" + "FCVTL v10.4s, v10.4h\n" + "FCVTL2 v13.4s, v12.8h\n" + "FCVTL v12.4s, v12.4h\n" + "ZIP1 v18.4s, v8.4s, v12.4s\n" + "LDR q14, [%[inptr7]], #16\n" + "FCVTL2 v15.4s, v14.8h\n" + "FCVTL v14.4s, v14.4h\n" + "ZIP1 v19.4s, v10.4s, v14.4s\n" ASM_PREFETCH("[%[inptr2], #128]") - "ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" ASM_PREFETCH("[%[inptr3], #128]") - - "ZIP2 v16.4s, v0.4s, v4.4s\n" - "ZIP2 v17.4s, v2.4s, v6.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Write back the first element of each source - - "ZIP2 v18.4s, v8.4s, v12.4s\n" ASM_PREFETCH("[%[inptr4], #128]") - "ZIP2 v19.4s, v10.4s, v14.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Write back the second element of each source - - "ZIP1 v20.4s, v16.4s, v17.4s\n" - "ZIP1 v21.4s, v18.4s, v19.4s\n" ASM_PREFETCH("[%[inptr5], #128]") - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" - - "ZIP1 v16.4s, v1.4s, v5.4s\n" - "ZIP1 v17.4s, v3.4s, v7.4s\n" ASM_PREFETCH("[%[inptr6], #128]") - "STP q20, q21, [%[outptr]], #32\n" // Third element - - "ZIP1 v18.4s, v9.4s, v13.4s\n" - "ZIP1 v19.4s, v11.4s, v15.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Fourth element + "ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + ASM_PREFETCH("[%[inptr3], #128]") + + "ZIP2 v16.4s, v0.4s, v4.4s\n" + "ZIP2 v17.4s, v2.4s, v6.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Write back the first element of each source + + "ZIP2 v18.4s, v8.4s, v12.4s\n" + ASM_PREFETCH("[%[inptr4], #128]") + "ZIP2 v19.4s, v10.4s, v14.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Write back the second element of each source + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + ASM_PREFETCH("[%[inptr5], #128]") + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP1 
v16.4s, v1.4s, v5.4s\n" + "ZIP1 v17.4s, v3.4s, v7.4s\n" + ASM_PREFETCH("[%[inptr6], #128]") + "STP q20, q21, [%[outptr]], #32\n" // Third element + + "ZIP1 v18.4s, v9.4s, v13.4s\n" + "ZIP1 v19.4s, v11.4s, v15.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Fourth element ASM_PREFETCH("[%[inptr7], #128]") - "ZIP1 v20.4s, v16.4s, v17.4s\n" - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" - "ZIP2 v16.4s, v1.4s, v5.4s\n" - "ZIP2 v17.4s, v3.4s, v7.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Fifth element + "ZIP2 v16.4s, v1.4s, v5.4s\n" + "ZIP2 v17.4s, v3.4s, v7.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Fifth element - "ZIP2 v18.4s, v9.4s, v13.4s\n" - "ZIP2 v19.4s, v11.4s, v15.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Sixth element + "ZIP2 v18.4s, v9.4s, v13.4s\n" + "ZIP2 v19.4s, v11.4s, v15.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Sixth element - "ZIP1 v20.4s, v16.4s, v17.4s\n" - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Seventh element + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Seventh element - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Eighth element - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Eighth element + : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), + [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7), [outptr] "+r" (outptr) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + ); } - for(; x > 0; x--) - { + for (;x>0;x--) { *outptr++ = *inptr0++; *outptr++ = *inptr1++; *outptr++ = *inptr2++; diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp index fd6a253c6a..6e07064a0c 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp @@ -31,105 +31,115 @@ template <> template inline void TransformImpl<6, 1, true, 4, 4>::Transform( - T *out, const T *const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax) -{ - // Redirect to a 12 x uint16_t specialisation - TransformImpl<12, 1, true, 2, 2>::Transform( - reinterpret_cast(out), - reinterpret_cast(in), - stride * 2, x0 * 2, xmax * 2, k0, kmax); + T* out, const T* const in, const int stride, + const int x0, const int xmax, const int k0, const int kmax +) { + // Redirect to a 12 x uint16_t specialisation + TransformImpl<12, 1, true, 2, 2>::Transform( + reinterpret_cast(out), + reinterpret_cast(in), + stride*2, x0*2, xmax*2, k0, kmax + ); } // Generic 12x16-bit sized specialisation template <> template inline void TransformImpl<12, 1, true, 2, 2>::Transform( 
- T *out, const T *const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax) -{ - // Redirect to a uint16_t specialisation - Transform( - reinterpret_cast(out), - reinterpret_cast(in), - stride, x0, xmax, k0, kmax); + T* out, const T* const in, const int stride, + const int x0, const int xmax, const int k0, const int kmax +) { + // Redirect to a uint16_t specialisation + Transform( + reinterpret_cast(out), + reinterpret_cast(in), + stride, x0, xmax, k0, kmax + ); } // Specialised 12 x uint16_t version template <> -inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out) -{ - __asm volatile( - "LDR q0, [%[in0]]\n" - "STR q0, [%[out]]\n" - "LDR d1, [%[in0], #0x10]\n" - "STR d1, [%[out], #0x10]\n" - "ADD %x[in0], %x[in0], #0x18\n" ASM_PREFETCH("[%[in0], #192]") - : [in0] "+r"(in0), - [out] "+r"(out) - : - : "v0", "v1", "memory"); +inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out) { + __asm volatile ( + "LDR q0, [%[in0]]\n" + "STR q0, [%[out]]\n" + "LDR d1, [%[in0], #0x10]\n" + "STR d1, [%[out], #0x10]\n" + "ADD %x[in0], %x[in0], #0x18\n" + ASM_PREFETCH("[%[in0], #192]") + : [in0] "+r" (in0), + [out] "+r" (out) + : + : "v0", "v1", "memory" + ); } template <> -inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *out) -{ - __asm volatile( - "LDR q0, [%[in0]]\n" - "LDR d1, [%[in0], #0x10]\n" - "ADD %x[in0], %x[in0], #0x18\n" ASM_PREFETCH("[%[in0], #192]") +inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *out) { + __asm volatile ( + "LDR q0, [%[in0]]\n" + "LDR d1, [%[in0], #0x10]\n" + "ADD %x[in0], %x[in0], #0x18\n" + ASM_PREFETCH("[%[in0], #192]") - "LDR x21, [%[in1]]\n" - "LDR q2, [%[in1], #0x08]\n" - "INS v1.d[1], x21\n" - "ADD %x[in1], %x[in1], #0x18\n" - "STP q0, q1, [%[out]]\n" - "STR q2, [%x[out], #0x20]\n" ASM_PREFETCH("[%[in1], #192]") - : [in0] "+r"(in0), - [in1] "+r"(in1), - [out] "+r"(out) - : - : "x21", "v0", "v1", "v2", "memory"); + "LDR x21, [%[in1]]\n" + "LDR q2, [%[in1], #0x08]\n" + "INS v1.d[1], x21\n" + "ADD %x[in1], %x[in1], #0x18\n" + "STP q0, q1, [%[out]]\n" + "STR q2, [%x[out], #0x20]\n" + ASM_PREFETCH("[%[in1], #192]") + : [in0] "+r" (in0), + [in1] "+r" (in1), + [out] "+r" (out) + : + : "x21", "v0", "v1", "v2", "memory" + ); } template <> -inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out) -{ - __asm __volatile( - "LDR q0, [%x[in0]], #0x10\n" - "STR q0, [%x[out]]\n" - "LDR d1, [%x[in0]], #0x08\n" ASM_PREFETCH("[%[in0], #192]") - "STR d1, [%x[out], #0x10]\n" +inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out) { + __asm __volatile ( + "LDR q0, [%x[in0]], #0x10\n" + "STR q0, [%x[out]]\n" + "LDR d1, [%x[in0]], #0x08\n" + ASM_PREFETCH("[%[in0], #192]") + "STR d1, [%x[out], #0x10]\n" - "LDR q0, [%x[in1]], #0x10\n" - "STR q0, [%x[out], #0x18]\n" - "LDR d1, [%x[in1]], #0x08\n" ASM_PREFETCH("[%[in1], #192]") - "STR d1, [%x[out], #0x28]\n" + "LDR q0, [%x[in1]], #0x10\n" + "STR q0, [%x[out], #0x18]\n" + "LDR d1, [%x[in1]], #0x08\n" + ASM_PREFETCH("[%[in1], #192]") + "STR d1, [%x[out], #0x28]\n" - "LDR q0, [%x[in2]], #0x10\n" 
- "STR q0, [%x[out], #0x30]\n" - "LDR d1, [%x[in2]], #0x08\n" ASM_PREFETCH("[%[in2], #192]") - "STR d1, [%x[out], #0x40]\n" + "LDR q0, [%x[in2]], #0x10\n" + "STR q0, [%x[out], #0x30]\n" + "LDR d1, [%x[in2]], #0x08\n" + ASM_PREFETCH("[%[in2], #192]") + "STR d1, [%x[out], #0x40]\n" - "LDR q0, [%x[in3]], #0x10\n" - "STR q0, [%x[out], #0x48]\n" - "LDR d1, [%x[in3]], #0x08\n" ASM_PREFETCH("[%[in3], #192]") "STR d1, [%x[out], #0x58]\n" - : [in0] "+r"(in0), - [in1] "+r"(in1), - [in2] "+r"(in2), - [in3] "+r"(in3), - [out] "+r"(out) - : - : "v0", "v1", "memory"); + "LDR q0, [%x[in3]], #0x10\n" + "STR q0, [%x[out], #0x48]\n" + "LDR d1, [%x[in3]], #0x08\n" + ASM_PREFETCH("[%[in3], #192]") + "STR d1, [%x[out], #0x58]\n" + : [in0] "+r" (in0), + [in1] "+r" (in1), + [in2] "+r" (in2), + [in3] "+r" (in3), + [out] "+r" (out) + : + : "v0", "v1", "memory" + ); } template <> template <> inline void TransformImpl<12, 1, true, 2, 2>::Transform( - uint16_t *out, const uint16_t *const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax) -{ - TransposeInterleaveCommon<12, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax); + uint16_t* out, const uint16_t* const in, const int stride, + const int x0, const int xmax, const int k0, const int kmax +) { + TransposeInterleaveCommon<12, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax); } #endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp index b79f32fb8b..2f90c18ebd 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp @@ -28,86 +28,93 @@ #include "transpose_interleave_common.hpp" template <> -inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x1(const __fp16 *&in0, float *out) -{ - __asm __volatile( +inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x1(const __fp16 *&in0, float *out) { + __asm __volatile ( "LDR q0, [%[in0]], #16\n" - "FCVTL2 v1.4s, v0.8h\n" - "FCVTL v0.4s, v0.4h\n" - "STP q0, q1, [%[out]]\n" ASM_PREFETCH("[%[in0], #192]") + "FCVTL2 v1.4s, v0.8h\n" + "FCVTL v0.4s, v0.4h\n" + "STP q0, q1, [%[out]]\n" + ASM_PREFETCH("[%[in0], #192]") "LDR d2, [%[in0]], #8\n" - "FCVTL v2.4s, v2.4h\n" + "FCVTL v2.4s, v2.4h\n" "STR q2, [%[out], #32]\n" - : [in0] "+r"(in0), [out] "+r"(out) - : - : "v0", "v1", "v2", "memory"); + : [in0] "+r" (in0), [out] "+r" (out) + : + : "v0", "v1", "v2", "memory" + ); } template <> -inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x2(const __fp16 *&in0, const __fp16 *&in1, float *out) -{ - __asm __volatile( +inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x2(const __fp16 *&in0, const __fp16 *&in1, float *out) { + __asm __volatile ( "LDR q0, [%[in0]], #16\n" - "FCVTL2 v1.4s, v0.8h\n" - "FCVTL v0.4s, v0.4h\n" - "STP q0, q1, [%[out]]\n" ASM_PREFETCH("[%[in0], #192]") + "FCVTL2 v1.4s, v0.8h\n" + "FCVTL v0.4s, v0.4h\n" + "STP q0, q1, [%[out]]\n" + ASM_PREFETCH("[%[in0], #192]") "LDR d2, [%[in0]], #8\n" - "FCVTL v2.4s, v2.4h\n" - "LDR q3, [%[in1]], #16\n" - "FCVTL2 v4.4s, v3.8h\n" - "FCVTL v3.4s, v3.4h\n" - "STP q2, q3, [%[out], #32]\n" ASM_PREFETCH("[%[in1], #192]") - "LDR d5, [%[in1]], #16\n" - "FCVTL v5.4s, v5.4h\n" + "FCVTL v2.4s, v2.4h\n" + "LDR q3, [%[in1]], #16\n" + "FCVTL2 v4.4s, v3.8h\n" + 
"FCVTL v3.4s, v3.4h\n" + "STP q2, q3, [%[out], #32]\n" + ASM_PREFETCH("[%[in1], #192]") + "LDR d5, [%[in1]], #16\n" + "FCVTL v5.4s, v5.4h\n" "STP q4, q5, [%[out], #64]\n" - : [in0] "+r"(in0), [in1] "+r"(in1), [out] "+r"(out) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "memory"); + : [in0] "+r" (in0), [in1] "+r" (in1), [out] "+r" (out) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "memory" + ); } template <> -inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x4(const __fp16 *&in0, const __fp16 *&in1, const __fp16 *&in2, const __fp16 *&in3, float *out) -{ - __asm __volatile( +inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x4(const __fp16 *&in0, const __fp16 *&in1, const __fp16 *&in2, const __fp16 *&in3, float *out) { + __asm __volatile ( "LDR q0, [%[in0]], #16\n" - "FCVTL2 v1.4s, v0.8h\n" - "FCVTL v0.4s, v0.4h\n" + "FCVTL2 v1.4s, v0.8h\n" + "FCVTL v0.4s, v0.4h\n" "STP q0, q1, [%[out]]\n" - "LDR d2, [%[in0]], #8\n" ASM_PREFETCH("[%[in0], #192]") - "FCVTL v2.4s, v2.4h\n" - "LDR q3, [%[in1]], #16\n" - "FCVTL2 v4.4s, v3.8h\n" - "FCVTL v3.4s, v3.4h\n" + "LDR d2, [%[in0]], #8\n" + ASM_PREFETCH("[%[in0], #192]") + "FCVTL v2.4s, v2.4h\n" + "LDR q3, [%[in1]], #16\n" + "FCVTL2 v4.4s, v3.8h\n" + "FCVTL v3.4s, v3.4h\n" "STP q2, q3, [%[out], #32]\n" - "LDR d5, [%[in1]], #8\n" - "FCVTL v5.4s, v5.4h\n" ASM_PREFETCH("[%[in1], #192]") + "LDR d5, [%[in1]], #8\n" + "FCVTL v5.4s, v5.4h\n" + ASM_PREFETCH("[%[in1], #192]") "STP q4, q5, [%[out], #64]\n" - "LDR q6, [%[in2]], #16\n" - "FCVTL2 v7.4s, v6.8h\n" - "FCVTL v6.4s, v6.4h\n" + "LDR q6, [%[in2]], #16\n" + "FCVTL2 v7.4s, v6.8h\n" + "FCVTL v6.4s, v6.4h\n" "STP q6, q7, [%[out], #96]\n" - "LDR d8, [%[in2]], #8\n" - "FCVTL v8.4s, v8.4h\n" ASM_PREFETCH("[%[in2], #192]") - "LDR q9, [%[in3]], #16\n" - "FCVTL2 v10.4s, v9.8h\n" - "FCVTL v9.4s, v9.4h\n" + "LDR d8, [%[in2]], #8\n" + "FCVTL v8.4s, v8.4h\n" + ASM_PREFETCH("[%[in2], #192]") + "LDR q9, [%[in3]], #16\n" + "FCVTL2 v10.4s, v9.8h\n" + "FCVTL v9.4s, v9.4h\n" "STP q8, q9, [%[out], #128]\n" - "LDR d11, [%[in3]], #8\n" - "FCVTL v11.4s, v11.4h\n" - "STP q10, q11, [%[out], #160]\n" ASM_PREFETCH("[%[in3], #192]") + "LDR d11, [%[in3]], #8\n" + "FCVTL v11.4s, v11.4h\n" + "STP q10, q11, [%[out], #160]\n" + ASM_PREFETCH("[%[in3], #192]") - : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), [in3] "+r"(in3), [out] "+r"(out) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "memory"); + : [in0] "+r" (in0), [in1] "+r" (in1), [in2] "+r" (in2), [in3] "+r" (in3), [out] "+r" (out) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "memory" + ); } template <> template <> inline void TransformImpl<12, 1, true, 4, 2>::Transform( - float *out, const __fp16 *const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax) -{ - TransposeInterleaveCommon<12, __fp16, float>::Transform(out, in, stride, x0, xmax, k0, kmax); + float* out, const __fp16* const in, const int stride, + const int x0, const int xmax, const int k0, const int kmax +) { + TransposeInterleaveCommon<12, __fp16, float>::Transform(out, in, stride, x0, xmax, k0, kmax); } #endif // __aarch64__ && __ARM_FP16_ARGS diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp index 5434599f03..b6565baa23 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp +++ 
b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp @@ -31,91 +31,100 @@ template <> template inline void TransformImpl<12, 1, true, 4, 4>::Transform( - T *out, const T *const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax) -{ - // Redirect to a 24 x uint16_t specialisation - TransformImpl<24, 1, true, 2, 2>::Transform( - reinterpret_cast(out), - reinterpret_cast(in), - stride * 2, x0 * 2, xmax * 2, k0, kmax); + T* out, const T* const in, const int stride, + const int x0, const int xmax, const int k0, const int kmax +) { + // Redirect to a 24 x uint16_t specialisation + TransformImpl<24, 1, true, 2, 2>::Transform( + reinterpret_cast(out), + reinterpret_cast(in), + stride*2, x0*2, xmax*2, k0, kmax + ); } // Generic 24x16-bit sized specialisation template <> template inline void TransformImpl<24, 1, true, 2, 2>::Transform( - T *out, const T *const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax) -{ - // Redirect to a uint16_t specialisation - Transform( - reinterpret_cast(out), - reinterpret_cast(in), - stride, x0, xmax, k0, kmax); + T* out, const T* const in, const int stride, + const int x0, const int xmax, const int k0, const int kmax +) { + // Redirect to a uint16_t specialisation + Transform( + reinterpret_cast(out), + reinterpret_cast(in), + stride, x0, xmax, k0, kmax + ); } // Specialised 24 x uint16_t version template <> -inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out) -{ - __asm __volatile( +inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out) { + __asm __volatile ( "LDP q0, q1, [%[in0]], #32\n" - "STP q0, q1, [%[out]]\n" ASM_PREFETCH("[%[in0], #192]") + "STP q0, q1, [%[out]]\n" + ASM_PREFETCH("[%[in0], #192]") "LDR q2, [%[in0]], #16\n" "STR q2, [%[out], #32]\n" - : [in0] "+r"(in0), [out] "+r"(out) - : - : "v0", "v1", "v2", "memory"); + : [in0] "+r" (in0), [out] "+r" (out) + : + : "v0", "v1", "v2", "memory" + ); } template <> -inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *out) -{ - __asm __volatile( +inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1,uint16_t *out) { + __asm __volatile ( "LDP q0, q1, [%[in0]], #32\n" - "STP q0, q1, [%[out]]\n" ASM_PREFETCH("[%[in0], #192]") + "STP q0, q1, [%[out]]\n" + ASM_PREFETCH("[%[in0], #192]") "LDR q2, [%[in0]], #16\n" - "LDP q3, q4, [%[in1]], #32\n" - "STP q2, q3, [%[out], #32]\n" ASM_PREFETCH("[%[in1], #192]") - "LDR q5, [%[in1]], #16\n" + "LDP q3, q4, [%[in1]], #32\n" + "STP q2, q3, [%[out], #32]\n" + ASM_PREFETCH("[%[in1], #192]") + "LDR q5, [%[in1]], #16\n" "STP q4, q5, [%[out], #64]\n" - : [in0] "+r"(in0), [in1] "+r"(in1), [out] "+r"(out) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "memory"); + : [in0] "+r" (in0), [in1] "+r" (in1), [out] "+r" (out) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "memory" + ); } template <> -inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out) -{ - __asm __volatile( +inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out) { + __asm __volatile ( "LDP q0, q1, [%[in0]], #32\n" "STP q0, 
q1, [%[out]]\n" - "LDR q2, [%[in0]], #16\n" ASM_PREFETCH("[%[in0], #192]") - "LDP q3, q4, [%[in1]], #32\n" + "LDR q2, [%[in0]], #16\n" + ASM_PREFETCH("[%[in0], #192]") + "LDP q3, q4, [%[in1]], #32\n" "STP q2, q3, [%[out], #32]\n" - "LDR q5, [%[in1]], #16\n" ASM_PREFETCH("[%[in1], #192]") + "LDR q5, [%[in1]], #16\n" + ASM_PREFETCH("[%[in1], #192]") "STP q4, q5, [%[out], #64]\n" - "LDP q6, q7, [%[in2]], #32\n" + "LDP q6, q7, [%[in2]], #32\n" "STP q6, q7, [%[out], #96]\n" - "LDR q8, [%[in2]], #16\n" ASM_PREFETCH("[%[in2], #192]") - "LDP q9, q10, [%[in3]], #32\n" + "LDR q8, [%[in2]], #16\n" + ASM_PREFETCH("[%[in2], #192]") + "LDP q9, q10, [%[in3]], #32\n" "STP q8, q9, [%[out], #128]\n" - "LDR q11, [%[in3]], #16\n" - "STP q10, q11, [%[out], #160]\n" ASM_PREFETCH("[%[in3], #192]") + "LDR q11, [%[in3]], #16\n" + "STP q10, q11, [%[out], #160]\n" + ASM_PREFETCH("[%[in3], #192]") - : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), [in3] "+r"(in3), [out] "+r"(out) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "memory"); + : [in0] "+r" (in0), [in1] "+r" (in1), [in2] "+r" (in2), [in3] "+r" (in3), [out] "+r" (out) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "memory" + ); } template <> template <> inline void TransformImpl<24, 1, true, 2, 2>::Transform( - uint16_t *out, const uint16_t *const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax) -{ - TransposeInterleaveCommon<24, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax); + uint16_t* out, const uint16_t* const in, const int stride, + const int x0, const int xmax, const int k0, const int kmax +) { + TransposeInterleaveCommon<24, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax); } -#endif // __arch64__ +#endif // __arch64__ diff --git a/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp b/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp index 3218ca1aac..63e85c155a 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp @@ -24,137 +24,117 @@ #pragma once template -struct TransposeInterleaveCommon -{ - // Override the moveblock_1xY methods to improve performance - static inline void moveblock_1x1(const TIn *&in0, TOut *out) - { - for(unsigned int i = 0; i < IntBy; i++) - { - *out++ = static_cast(*in0++); - } +struct TransposeInterleaveCommon { + // Override the moveblock_1xY methods to improve performance + static inline void moveblock_1x1(const TIn *&in0, TOut *out) { + for (unsigned int i = 0; i < IntBy; i++) { + *out++ = static_cast(*in0++); } + } - static inline void moveblock_1x2(const TIn *&in0, const TIn *&in1, TOut *out) - { - for(unsigned int i = 0; i < IntBy; i++) - { - *out++ = static_cast(*in0++); - } - for(unsigned int i = 0; i < IntBy; i++) - { - *out++ = static_cast(*in1++); - } + static inline void moveblock_1x2(const TIn *&in0, const TIn *&in1, TOut *out) { + for (unsigned int i = 0; i < IntBy; i++) { + *out++ = static_cast(*in0++); + } + for (unsigned int i = 0; i < IntBy; i++) { + *out++ = static_cast(*in1++); } + } - static inline void moveblock_1x4(const TIn *&in0, const TIn *&in1, const TIn *&in2, const TIn *&in3, TOut *out) - { - for(unsigned int i = 0; i < IntBy; i++) - { - *out++ = static_cast(*in0++); - } - for(unsigned int i = 0; i < IntBy; i++) - { - *out++ = static_cast(*in1++); - } - for(unsigned int i = 0; i < IntBy; i++) 
- { - *out++ = static_cast(*in2++); - } - for(unsigned int i = 0; i < IntBy; i++) - { - *out++ = static_cast(*in3++); + static inline void moveblock_1x4(const TIn *&in0, const TIn *&in1, const TIn *&in2, const TIn *&in3, TOut *out) { + for (unsigned int i = 0; i < IntBy; i++) { + *out++ = static_cast(*in0++); + } + for (unsigned int i = 0; i < IntBy; i++) { + *out++ = static_cast(*in1++); + } + for (unsigned int i = 0; i < IntBy; i++) { + *out++ = static_cast(*in2++); + } + for (unsigned int i = 0; i < IntBy; i++) { + *out++ = static_cast(*in3++); + } + } + + static inline void Transform(TOut *out, const TIn *in, const int stride, const int x0, const int xmax, const int k0, const int kmax) { + const auto ldin = stride; + + TOut *outarray = out; + const TIn *inarray = in; + TOut *outptr_base = outarray; + const TIn *inptr_base = inarray + x0 + (k0 * ldin); + int ldout = (kmax - k0) * IntBy; + + int k=(kmax-k0); + for ( ; k>3; k-=4) { + TOut *outptr = outptr_base; + const TIn *inptr = inptr_base; + const TIn *inptr1 = inptr + ldin; + const TIn *inptr2 = inptr1 + ldin; + const TIn *inptr3 = inptr2 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + prefetch_3x(inptr3); + + outptr_base += IntBy * 4; + inptr_base += ldin * 4; + + for (int x = (xmax-x0) / IntBy; x > 0 ; x--) { + moveblock_1x4(inptr, inptr1, inptr2, inptr3, outptr); + outptr += ldout; } } - static inline void Transform(TOut *out, const TIn *in, const int stride, const int x0, const int xmax, const int k0, const int kmax) - { - const auto ldin = stride; - - TOut *outarray = out; - const TIn *inarray = in; - TOut *outptr_base = outarray; - const TIn *inptr_base = inarray + x0 + (k0 * ldin); - int ldout = (kmax - k0) * IntBy; - - int k = (kmax - k0); - for(; k > 3; k -= 4) - { - TOut *outptr = outptr_base; - const TIn *inptr = inptr_base; - const TIn *inptr1 = inptr + ldin; - const TIn *inptr2 = inptr1 + ldin; - const TIn *inptr3 = inptr2 + ldin; - - prefetch_3x(inptr); - prefetch_3x(inptr1); - prefetch_3x(inptr2); - prefetch_3x(inptr3); - - outptr_base += IntBy * 4; - inptr_base += ldin * 4; - - for(int x = (xmax - x0) / IntBy; x > 0; x--) - { - moveblock_1x4(inptr, inptr1, inptr2, inptr3, outptr); - outptr += ldout; + if (k) { + TOut *outptr = outptr_base; + const TIn *inptr = inptr_base; + const TIn *inptr1 = inptr + ldin; + const TIn *inptr2 = inptr1 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + + for (int x = (xmax-x0) / IntBy; x > 0 ; x--) { + switch(k) { + case 3: + moveblock_1x2(inptr, inptr1, outptr); + moveblock_1x1(inptr2, outptr + IntBy * 2); + break; + + case 2: + moveblock_1x2(inptr, inptr1, outptr); + break; + + case 1: + moveblock_1x1(inptr, outptr); + break; + + default: + UNREACHABLE("Impossible."); } - } - if(k) - { - TOut *outptr = outptr_base; - const TIn *inptr = inptr_base; - const TIn *inptr1 = inptr + ldin; - const TIn *inptr2 = inptr1 + ldin; - - prefetch_3x(inptr); - prefetch_3x(inptr1); - prefetch_3x(inptr2); - - for(int x = (xmax - x0) / IntBy; x > 0; x--) - { - switch(k) - { - case 3: - moveblock_1x2(inptr, inptr1, outptr); - moveblock_1x1(inptr2, outptr + IntBy * 2); - break; - - case 2: - moveblock_1x2(inptr, inptr1, outptr); - break; - - case 1: - moveblock_1x1(inptr, outptr); - break; - - default: - UNREACHABLE("Impossible."); - } - - outptr += ldout; - } + outptr += ldout; } + } + + // Cope with ragged X cases + const unsigned int overflow = (xmax - x0) % IntBy; + if (overflow) { + const TIn *inptr_base = inarray + (xmax - overflow) + 
(k0 * ldin);
+      TOut *outptr = outarray + ((xmax - x0) / IntBy) * ldout;
+
+      for (int k=(kmax-k0); k>0; k--) {
+        const TIn *inptr = inptr_base;
+        inptr_base += ldin;
 
-        // Cope with ragged X cases
-        const unsigned int overflow = (xmax - x0) % IntBy;
-        if(overflow)
-        {
-            const TIn *inptr_base = inarray + (xmax - overflow) + (k0 * ldin);
-            TOut *outptr = outarray + ((xmax - x0) / IntBy) * ldout;
-
-            for(int k = (kmax - k0); k > 0; k--)
-            {
-                const TIn *inptr = inptr_base;
-                inptr_base += ldin;
-
-                for(unsigned int x = 0; x < IntBy; x++)
-                {
-                    TOut val = (x < overflow) ? static_cast<TOut>(*inptr++) : static_cast<TOut>(0);
-                    *outptr++ = val;
-                }
+        for (unsigned int x=0; x < IntBy; x++) {
+          TOut val = (x < overflow) ? static_cast<TOut>(*inptr++) : static_cast<TOut>(0);
+          *outptr++ = val;
         }
       }
     }
+}
 };
diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp
index 6c5b92ae8f..c1977d5f3e 100644
--- a/src/core/NEON/kernels/arm_gemm/utils.hpp
+++ b/src/core/NEON/kernels/arm_gemm/utils.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -25,27 +25,22 @@
 #pragma once
 
 // Macro for unreachable code (e.g. impossible default cases on switch)
-#define UNREACHABLE(why) __builtin_unreachable()
+#define UNREACHABLE(why)      __builtin_unreachable()
 
 // Paranoid option for the above with assert
 // #define UNREACHABLE(why) assert(0 && why)
 
-inline int iceildiv(const int a, const int b)
-{
-    return (a + b - 1) / b;
+inline int iceildiv(const int a, const int b) {
+    return (a + b - 1) / b;
 }
 
 template <typename T>
-inline T roundup(const T a, const T b)
-{
-    T rem = a % b;
+inline T roundup(const T a, const T b) {
+    T rem = a % b;
 
-    if(rem)
-    {
-        return a + b - rem;
-    }
-    else
-    {
-        return a;
-    }
+    if (rem) {
+        return a + b - rem;
+    } else {
+        return a;
+    }
 }
-- 
cgit v1.2.1
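
Editor's note on the layout: the generic TransformImpl::Transform touched by this patch produces a blocked, zero-padded interleave (IntBy rows per group, blank rows and columns filled with zeros), which the NEON specialisations above reproduce with LDR/ZIP/STP sequences. The scalar sketch below is only an illustration of that layout for the BlockBy == 1, non-transposed case; interleave_ref and the toy main() are hypothetical names and are not part of the patch or of arm_gemm.

    #include <cstdio>
    #include <vector>

    // Copy an (ymax-y0) x (kmax-k0) block of 'in' (row-major, leading dimension
    // ldin) into 'out', IntBy rows at a time: for each row group and each
    // column, emit IntBy values, padding rows past ymax with zeros.
    template <int IntBy, typename T>
    void interleave_ref(T *out, const T *in, int ldin,
                        int y0, int ymax, int k0, int kmax) {
        for (int y = y0; y < ymax; y += IntBy) {        // "Y" loop: IntBy rows per group
            for (int k = k0; k < kmax; k++) {           // walk along the columns
                for (int r = 0; r < IntBy; r++) {
                    // Rows beyond ymax are "blank rows": fill with zero.
                    *out++ = ((y + r) < ymax) ? in[(y + r) * ldin + k]
                                              : static_cast<T>(0);
                }
            }
        }
    }

    int main() {
        const int rows = 5, cols = 3;                   // deliberately ragged: 5 rows, IntBy = 4
        std::vector<float> in(rows * cols);
        for (int i = 0; i < rows * cols; i++) in[i] = float(i);

        std::vector<float> out(8 * cols);               // two row groups of 4 x 3 values
        interleave_ref<4>(out.data(), in.data(), cols, 0, rows, 0, cols);

        // Second group prints 12,0,0,0, 13,0,0,0, 14,0,0,0 - the zero padding
        // that the ragged-case handling in the kernels implements via zerobuff.
        for (size_t i = 0; i < out.size(); i++) printf("%g ", out[i]);
        printf("\n");
        return 0;
    }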