diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-05-03 13:44:35 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:51:17 +0000 |
commit | 932b561159cd6a8c9230bbd0343790c85755846e (patch) | |
tree | f5345af71844e8f78873258bddec2cb37b39b8f8 /arm_compute/runtime/NEON/AssemblyHelper.h | |
parent | 563494c2f447e201e88e6d7133a41e12971777eb (diff) | |
download | ComputeLibrary-932b561159cd6a8c9230bbd0343790c85755846e.tar.gz |
COMPMID-959: Perform pretranspose if allowed on NEON assembly
Change-Id: I281699ce7270aec1317c47b5a13799954cf6c9e8
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/130010
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'arm_compute/runtime/NEON/AssemblyHelper.h')
-rw-r--r-- | arm_compute/runtime/NEON/AssemblyHelper.h | 48 |
1 files changed, 34 insertions, 14 deletions
diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h index 2b4f35f2e1..ee09ef531e 100644 --- a/arm_compute/runtime/NEON/AssemblyHelper.h +++ b/arm_compute/runtime/NEON/AssemblyHelper.h @@ -51,7 +51,7 @@ public: using TypeResult = TypeOutput; /** Default constructor. */ AssemblyKernelGlue() - : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr) + : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr), _pretranspose(nullptr) { } /** Assembly Gemm */ @@ -72,6 +72,8 @@ public: const ITensor *_b; /** Output */ ITensor *_d; + /** Pre-transpose tensor */ + ITensor *_pretranspose; /** Configures the arrays pointers and strides in the assembly kernel and executes the assembly kernel. * The call to set_arrays is needed to deal with the input sizes containing batches (dims > 2) @@ -94,6 +96,12 @@ public: auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer()); _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d); + if(_gemm_kernel_asm->B_pretranspose_required()) + { + ARM_COMPUTE_ERROR_ON(_pretranspose == nullptr || _pretranspose->buffer() == nullptr); + _gemm_kernel_asm->pretranspose_B_array(reinterpret_cast<void *>(_pretranspose->buffer()), in1_ptr, ldb, multi_stride_b); + } + NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX); } }; @@ -113,8 +121,9 @@ using AssemblyKernelGlueS8S32 = AssemblyKernelGlue<int8_t, int32_t>; * @param[in] alignment Workspace memory alignment. * @param[in] num_threads Number of workspace threads. */ -inline void allocate_workspace(size_t workspace_size, Tensor &workspace, MemoryGroup &memory_group, size_t alignment, unsigned int num_threads) +inline void allocate_workspace(size_t workspace_size, Tensor &workspace, MemoryGroup *memory_group, size_t alignment, unsigned int num_threads) { + ARM_COMPUTE_UNUSED(memory_group); ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0"); workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment - 1) * num_threads }, 1, DataType::S8)); workspace.allocator()->allocate(); @@ -122,20 +131,22 @@ inline void allocate_workspace(size_t workspace_size, Tensor &workspace, MemoryG /** Create a wrapper kernel. * - * @param[in] a Input tensor A. - * @param[in] b Input tensor B. - * @param[out] d Output tensor. - * @param[in] alpha Alpha value. - * @param[in] beta Beta value. - * @param[out] workspace Workspace tensor - * @param[in] memory_group Tensor memory group. - * @param[out] asm_glue Assembly glue kernel. + * @param[in] a Input tensor A. + * @param[in] b Input tensor B. + * @param[out] d Output tensor. + * @param[in] alpha Alpha value. + * @param[in] beta Beta value. + * @param[in] pretranspose_hint Pre-transpose hint in case matrix b should be pre-transposed + * @param[out] workspace Workspace tensor + * @param[out] B_pretranspose Tensor to hold the pre-transposed B + * @param[in] memory_group Tensor memory group. + * @param[out] asm_glue Assembly glue kernel. * * @return the wrapper kernel. */ template <typename T> -inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, - Tensor &workspace, MemoryGroup &memory_group, T &asm_glue) +inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, + Tensor &workspace, Tensor &B_pretranspose, MemoryGroup &memory_group, T &asm_glue) { const CPUInfo &ci = NEScheduler::get().cpu_info(); const int M = d->info()->tensor_shape().y(); @@ -147,7 +158,7 @@ inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d // unique_ptr to a Gemm object std::unique_ptr<typename T::AssemblyGemm> - asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, false)); + asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, pretranspose_hint)); // arm_compute wrapper for the Gemm object (see above) std::unique_ptr<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>> acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>>(); @@ -159,7 +170,7 @@ inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d { // Allocate workspace const unsigned int alignment = 4096; - allocate_workspace(workspace_size, workspace, memory_group, alignment, num_threads); + allocate_workspace(workspace_size, workspace, &memory_group, alignment, num_threads); ARM_COMPUTE_ERROR_ON_NULLPTR(workspace.buffer()); asm_gemm->set_working_space(reinterpret_cast<typename T::TypeResult *>(workspace.buffer())); } @@ -175,6 +186,15 @@ inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d } } + // Check for pre-transposed support + if(asm_gemm->B_pretranspose_required()) + { + const size_t B_pretranspose_size = asm_gemm->get_B_pretransposed_array_size(); + allocate_workspace(B_pretranspose_size, B_pretranspose, nullptr, 1, 1); + ARM_COMPUTE_ERROR_ON_NULLPTR(B_pretranspose.buffer()); + asm_glue._pretranspose = &B_pretranspose; + } + asm_glue._gemm_kernel_asm = std::move(asm_gemm); asm_glue._optimised_kernel = std::move(acl_gemm_wrapper); // We need to setup the ptrs in the run() method |