diff options
Diffstat (limited to 'arm_compute')
-rw-r--r-- | arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp | 7 | ||||
-rw-r--r-- | arm_compute/core/NEON/kernels/assembly/gemm_common.hpp | 18 | ||||
-rw-r--r-- | arm_compute/runtime/NEON/AssemblyHelper.h | 32 |
3 files changed, 36 insertions, 21 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp index d6c9931a21..0a541c6db9 100644 --- a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp +++ b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp @@ -34,6 +34,9 @@ template<typename Top, typename Tret> using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret> >; template<typename Top, typename Tret> -UniqueGemmCommon<Top, Tret> gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K, const bool trA, const bool trB, const Tret alpha, const Tret beta, const int maxthreads, const bool pretransposed_hint); - +UniqueGemmCommon<Top, Tret> gemm(const CPUInfo &ci, + const unsigned int M, const unsigned int N, const unsigned int K, + const unsigned int nbatches, const unsigned int nmulti, + const bool trA, const bool trB, const Tret alpha, const Tret beta, + const int maxthreads, const bool pretransposed_hint); } // namespace arm_gemm diff --git a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp index 7f47abcbb9..3919c339bf 100644 --- a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp +++ b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp @@ -39,23 +39,35 @@ class GemmCommon { protected: const To *_Aptr=nullptr; int _lda=0; + int _A_batch_stride=0; + int _A_multi_stride=0; const To *_Bptr=nullptr; int _ldb=0; + int _B_multi_stride=0; Tr *_Cptr=nullptr; int _ldc=0; + int _C_batch_stride=0; + int _C_multi_stride=0; public: /* Pass in the pointers to the arrays to be operated on and their * strides. This has a default implementation that just captures them * all in protected members. If B is pretransposed (see below) then the * settings for B here are ignored. */ - virtual void set_arrays(const To *A, const int lda, const To *B, const int ldb, Tr *C, const int ldc) { + virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride, + const To *B, const int ldb, /* batches share B */ const int B_multi_stride, + Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) { _Aptr = A; _lda = lda; + _A_batch_stride = A_batch_stride; + _A_multi_stride = A_multi_stride; _Bptr = B; _ldb = ldb; + _B_multi_stride = B_multi_stride; _Cptr = C; _ldc = ldc; + _C_batch_stride = C_batch_stride; + _C_multi_stride = C_multi_stride; } /* For threading, we divide the work into some number of units and work @@ -95,7 +107,9 @@ public: /* Total number of bytes of space needed for pretransposed arrays. */ virtual size_t get_B_pretransposed_array_size() const { return 0; } /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */ - virtual void pretranspose_B_array(void *buffer, const To *, const int) { }; + virtual void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) { }; + /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */ + virtual void set_pretransposed_B_data(void *buffer) { } // Destructor virtual ~GemmCommon() { } diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h index 40f28587c2..2b4f35f2e1 100644 --- a/arm_compute/runtime/NEON/AssemblyHelper.h +++ b/arm_compute/runtime/NEON/AssemblyHelper.h @@ -82,24 +82,19 @@ public: const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput); - // Configure kernel window - Window window = calculate_max_window(*_d->info()); + const int batch_stride_a = _a->info()->strides_in_bytes().z() / sizeof(TypeInput); + const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput); + + const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput); + const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput); + const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput); + + const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer()); const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer()); + auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer()); - // Only iterate over batches - Window win(window); - win.set(0, Window::Dimension(0, 1, 1)); - win.set(1, Window::Dimension(0, 1, 1)); - Iterator in0(_a, window); - Iterator out(_d, window); - execute_window_loop(win, [&](const Coordinates &) - { - const auto in0_ptr = reinterpret_cast<const TypeInput *>(in0.ptr()); - auto out_ptr = reinterpret_cast<TypeOutput *>(out.ptr()); - _gemm_kernel_asm->set_arrays(in0_ptr, lda, in1_ptr, ldb, out_ptr, ldd); - NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX); - }, - in0, out); + _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d); + NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX); } }; @@ -146,10 +141,13 @@ inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d const int M = d->info()->tensor_shape().y(); const int N = d->info()->tensor_shape().x(); const int K = a->info()->tensor_shape().x(); + const int batches = a->info()->tensor_shape().total_size_upper(2); + const int multis = b->info()->tensor_shape().z(); unsigned int num_threads = NEScheduler::get().num_threads(); + // unique_ptr to a Gemm object std::unique_ptr<typename T::AssemblyGemm> - asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(ci, M, N, K, false, false, alpha, beta, num_threads, false)); + asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, false)); // arm_compute wrapper for the Gemm object (see above) std::unique_ptr<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>> acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>>(); |