From eb82fd2aa786715c3b6a941dc6d6deac4ce8e2a0 Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Fri, 23 Feb 2018 13:43:50 +0000 Subject: COMPMID-881: RSH new arm_gemm interface. Change-Id: I1e2a1a77097d8017c274af3f97eba6964f80f5fa Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122592 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- src/runtime/NEON/functions/NEGEMM.cpp | 118 ++------------------ .../NEON/functions/NEGEMMConvolutionLayer.cpp | 79 ++------------ .../NEGEMMLowpAssemblyMatrixMultiplyCore.cpp | 117 +++++--------------- .../functions/NEGEMMLowpMatrixMultiplyCore.cpp | 120 +++++++++------------ 4 files changed, 97 insertions(+), 337 deletions(-) (limited to 'src/runtime/NEON') diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp index 05907bab07..c8cba8a174 100644 --- a/src/runtime/NEON/functions/NEGEMM.cpp +++ b/src/runtime/NEON/functions/NEGEMM.cpp @@ -26,37 +26,20 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensor.h" -#include "arm_compute/core/NEON/kernels/arm32/NEGEMMAArch32Kernel.h" -#include "arm_compute/core/NEON/kernels/arm64/NEGEMMAArch64Kernel.h" -#include "arm_compute/core/NEON/kernels/arm64/NEGEMVAArch64Kernel.h" -#include "arm_compute/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" +#include "arm_compute/runtime/NEON/AssemblyHelper.h" #include "arm_compute/runtime/NEON/NEScheduler.h" #include "arm_compute/runtime/TensorAllocator.h" #include "support/ToolchainSupport.h" -namespace arm_compute -{ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wswitch-default" -#pragma GCC diagnostic ignored "-Weffc++" -#include "arm_compute/core/NEON/kernels/assembly/gemm_interleaved.hpp" -#include "arm_compute/core/NEON/kernels/assembly/gemv_transposed.hpp" -#include "arm_compute/core/NEON/kernels/assembly/kernels/a32_sgemm_8x6.hpp" -#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_hgemm_24x8.hpp" -#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_sgemm_12x8.hpp" -#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_sgemv_trans.hpp" -#pragma GCC diagnostic pop -} // namespace arm_compute - #include namespace arm_compute { NEGEMM::NEGEMM(std::shared_ptr memory_manager) - : _memory_group(std::move(memory_manager)), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _mm_optimised_kernel(nullptr), _ma_kernel(), _tmp_a(), _tmp_b(), _workspace(), + : _memory_group(std::move(memory_manager)), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(), _ma_kernel(), _tmp_a(), _tmp_b(), _workspace(), _run_vector_matrix_multiplication(false), _run_addition(false), _is_first_run(true), _reshape_b_only_on_first_run(false) { } @@ -82,42 +65,13 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe // Check if we need to reshape the matrix B only on the first run _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run(); _run_vector_matrix_multiplication = a->info()->dimension(1) < 2; + const bool run_optimised = setup_assembly_kernel(a, b, c, d, alpha, beta, _workspace, _memory_group, _asm_glue); // Check if the first input tensor is a vector. // If so, all the kernels for reshaping the tensors can be skipped if(_run_vector_matrix_multiplication) { -#if defined(__aarch64__) - if(NEScheduler::get().cpu_info().CPU >= CPUTarget::ARMV8 && a->info()->data_type() == DataType::F32 && (c == nullptr || beta == 0.f)) - { - _mm_optimised_kernel = support::cpp14::make_unique(); - } - - if(_mm_optimised_kernel != nullptr) - { - struct CPUInfo ci = NEScheduler::get().cpu_info(); - - const int N = d->info()->tensor_shape().x(); - const int K = a->info()->tensor_shape().x(); - - size_t workbench_size = 0; - - if(a->info()->data_type() == DataType::F32) - { - workbench_size = GemvTransposed(&ci, N, K).get_working_size(); - } - - constexpr size_t alignment = 4096; - ARM_COMPUTE_ERROR_ON_MSG(workbench_size == 0, "size cannot be 0"); - _workspace.allocator()->init(TensorInfo(TensorShape{ (workbench_size + alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::S8)); - _memory_group.manage(&_workspace); - - // Configure matrix multiplication kernel - _mm_optimised_kernel->configure(a, b, d, &_workspace, alpha, 0.f, false /* is_transposed_0 */, false /* is_transposed_1 */); - _workspace.allocator()->allocate(); - } - else -#endif /* defined(__aarch64__) */ + if(!run_optimised) { // Configure the matrix multiply kernel _mm_kernel.configure(a, b, d, alpha, false); @@ -132,65 +86,7 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe } else { -#if defined(__arm__) - if(NEScheduler::get().cpu_info().CPU == CPUTarget::ARMV7 && a->info()->data_type() == DataType::F32 && (c == nullptr || beta == 0.f)) - { - _mm_optimised_kernel = support::cpp14::make_unique(); - } -#elif defined(__aarch64__) - if(NEScheduler::get().cpu_info().CPU >= CPUTarget::ARMV8 && a->info()->data_type() == DataType::F32 && (c == nullptr || beta == 0.f)) - { - _mm_optimised_kernel = support::cpp14::make_unique(); - } - else if(a->info()->data_type() == DataType::F16 && (c == nullptr || beta == 0.f)) - { -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - _mm_optimised_kernel = support::cpp14::make_unique(); -#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - ARM_COMPUTE_ERROR("Recompile the library with arch=arm64-v8.2-a to enable support for FP16."); -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - } -#endif /* defined(__arm__) || defined(__aarch64__) */ - -#if defined(__arm__) || defined(__aarch64__) - if(_mm_optimised_kernel != nullptr) - { - struct CPUInfo ci = NEScheduler::get().cpu_info(); - - const int M = d->info()->tensor_shape().y(); - const int N = d->info()->tensor_shape().x(); - const int K = a->info()->tensor_shape().x(); - - size_t workbench_size = 0; - -#if defined(__arm__) - workbench_size = GemmInterleaved(&ci, M, N, K, false, false).get_working_size(); -#elif defined(__aarch64__) - if(a->info()->data_type() == DataType::F32) - { - workbench_size = GemmInterleaved(&ci, M, N, K, false, false).get_working_size(); - } - else if(a->info()->data_type() == DataType::F16) - { -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - workbench_size = GemmInterleaved(&ci, M, N, K, false, false).get_working_size(); -#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - ARM_COMPUTE_ERROR("Recompile the library with arch=arm64-v8.2-a to enable support for FP16."); -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - } -#endif /* defined(__arm__) || defined(__aarch64__) */ - - constexpr size_t alignment = 4096; - ARM_COMPUTE_ERROR_ON_MSG(workbench_size == 0, "size cannot be 0"); - _workspace.allocator()->init(TensorInfo(TensorShape{ (workbench_size + alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::S8)); - _memory_group.manage(&_workspace); - - // Configure matrix multiplication kernel - _mm_optimised_kernel->configure(a, b, d, &_workspace, alpha, 0.f, false /* is_transposed_0 */, false /* is_transposed_1 */); - _workspace.allocator()->allocate(); - } - else -#endif /* defined(__arm__) || defined(__aarch64__) */ + if(!run_optimised) { TensorShape shape_tmp_a = a->info()->tensor_shape(); TensorShape shape_tmp_b = b->info()->tensor_shape(); @@ -243,9 +139,9 @@ void NEGEMM::run() { _memory_group.acquire(); - if(_mm_optimised_kernel != nullptr) + if(_asm_glue._optimised_kernel != nullptr) { - NEScheduler::get().schedule(_mm_optimised_kernel.get(), Window::DimY); + _asm_glue.run(); _memory_group.release(); } else diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp index a85078cf71..3b8b4243e5 100644 --- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp @@ -23,9 +23,6 @@ */ #include "arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h" -#include "arm_compute/core/NEON/kernels/arm32/NEGEMMAArch32Kernel.h" -#include "arm_compute/core/NEON/kernels/arm64/NEGEMMAArch64Kernel.h" -#include "arm_compute/core/NEON/kernels/arm64/NEGEMMAArch64NativeKernel.h" #include "arm_compute/core/PixelValue.h" #include "arm_compute/core/Size2D.h" #include "arm_compute/core/Utils.h" @@ -34,13 +31,6 @@ #include "arm_compute/runtime/NEON/NEScheduler.h" #include "support/ToolchainSupport.h" -namespace arm_compute -{ -#include "arm_compute/core/NEON/kernels/assembly/gemm_interleaved.hpp" -#include "arm_compute/core/NEON/kernels/assembly/kernels/a32_sgemm_8x6.hpp" -#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_sgemm_12x8.hpp" -} // namespace arm_compute - #include #include @@ -226,8 +216,8 @@ Status validate_and_initialize_values(const ITensorInfo *input, const ITensorInf } // namespace NEGEMMConvolutionLayer::NEGEMMConvolutionLayer(const std::shared_ptr &memory_manager) - : _memory_group(memory_manager), _input_im2col_kernel(), _input_interleave_kernel(), _reshape_weights(), _mm_kernel(), _mm_optimised_kernel(nullptr), _mm_gemmlowp(memory_manager), - _gemmlowp_output_stage(), _output_col2im_kernel(), _input_im2col_reshaped(), _input_interleaved_reshaped(), _weights_reshaped(), _gemm_output(), _tmp_output(), _workspace(), _append_bias(false), + : _asm_glue(), _memory_group(memory_manager), _input_im2col_kernel(), _input_interleave_kernel(), _reshape_weights(), _mm_kernel(), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), + _output_col2im_kernel(), _input_im2col_reshaped(), _input_interleaved_reshaped(), _weights_reshaped(), _gemm_output(), _tmp_output(), _workspace(), _append_bias(false), _is_fully_connected_convolution(false), _are_weights_reshaped(false), _is_quantized(false), _is_interleaved(false) { } @@ -256,25 +246,6 @@ void NEGEMMConvolutionLayer::configure_mm(const ITensor *input, const ITensor *w } } -void NEGEMMConvolutionLayer::configure_asm_mm(const struct CPUInfo &ci, int M, int N, int K) -{ - ARM_COMPUTE_UNUSED(ci); - ARM_COMPUTE_UNUSED(M); - ARM_COMPUTE_UNUSED(N); - ARM_COMPUTE_UNUSED(K); -#if defined(__arm__) || defined(__aarch64__) -#if defined(__arm__) - GemmInterleaved gemm(&ci, M, N, K, false, false); -#elif defined(__aarch64__) - GemmInterleaved gemm(&ci, M, N, K, false, false); -#endif /* defined(__arm__) || defined(__aarch64__) */ - - constexpr size_t alignment = 4096; - _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8)); - _memory_group.manage(&_workspace); -#endif /* defined(__arm__) || defined(__aarch64__) */ -} - void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info) { // Perform validate step @@ -298,20 +269,11 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig const unsigned int fixed_point_position = input->info()->fixed_point_position(); const ITensor *biases_to_use = (_append_bias) ? biases : nullptr; -#if defined(__arm__) - if(NEScheduler::get().cpu_info().CPU == CPUTarget::ARMV7 && dt == DataType::F32) - { - _mm_optimised_kernel = support::cpp14::make_unique(); - } -#elif defined(__aarch64__) - if(NEScheduler::get().cpu_info().CPU >= CPUTarget::ARMV8 && dt == DataType::F32) - { - _mm_optimised_kernel = support::cpp14::make_unique(); - } -#endif /* defined(__arm__) || defined(__aarch64__) */ + bool run_optimised = + (NEScheduler::get().cpu_info().CPU == CPUTarget::ARMV7 && dt == DataType::F32) || (NEScheduler::get().cpu_info().CPU >= CPUTarget::ARMV8 && dt == DataType::F32); // Reshape weights if needed - if(_mm_optimised_kernel != nullptr) + if(run_optimised) { if(_are_weights_reshaped) { @@ -378,7 +340,7 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig _memory_group.manage(&_input_im2col_reshaped); // Create tensor (interleave) to prepare input tensor for GEMM - if(!_is_fully_connected_convolution && _mm_optimised_kernel == nullptr) + if(!_is_fully_connected_convolution && !run_optimised) { TensorShape shape_interleaved(shape_im2col); shape_interleaved.set(0, shape_interleaved.x() * 4); @@ -403,29 +365,10 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig _input_im2col_kernel.configure(input, &_input_im2col_reshaped, Size2D(kernel_width, kernel_height), conv_info, _append_bias); // Configure matrix multiply - if(_mm_optimised_kernel != nullptr) + if(run_optimised) { - struct CPUInfo ci = NEScheduler::get().cpu_info(); - - const int M = _gemm_output.info()->tensor_shape().y(); - const int N = _gemm_output.info()->tensor_shape().x(); - const int K = _input_im2col_reshaped.info()->tensor_shape().x(); - -#if defined(__aarch64__) - if((N <= 128) && (K <= 128)) - { - _mm_optimised_kernel = support::cpp14::make_unique(); - } - else -#endif /* defined(__aarch64__) */ - { - configure_asm_mm(ci, M, N, K); - } - - // Configure matrix multiplication kernel - _mm_optimised_kernel->configure(&_input_im2col_reshaped, weights, &_gemm_output, &_workspace); - - _workspace.allocator()->allocate(); + run_optimised = setup_assembly_kernel(&_input_im2col_reshaped, weights, nullptr, &_gemm_output, 1.f, 0.f, _workspace, _memory_group, _asm_glue); + ARM_COMPUTE_ERROR_ON_MSG(run_optimised == false, "setup_assembly_kernel failed."); } else { @@ -615,9 +558,9 @@ void NEGEMMConvolutionLayer::run() NEScheduler::get().schedule(&_input_im2col_kernel, Window::DimY); // Runs matrix multiply on reshaped matrices - if(_mm_optimised_kernel != nullptr) + if(_asm_glue._optimised_kernel != nullptr) { - NEScheduler::get().schedule(_mm_optimised_kernel.get(), Window::DimY); + _asm_glue.run(); } else { diff --git a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp index 9b36e81afd..e5e97910d8 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp @@ -1,4 +1,4 @@ -/* Copyright (c) 2017 ARM Limited. +/* Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -25,13 +25,9 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensor.h" -#include "arm_compute/core/NEON/kernels/NEGEMMAssemblyBaseKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h" -#include "arm_compute/core/NEON/kernels/arm64/NEGEMMLowpAArch64A53Kernel.h" -#include "arm_compute/core/NEON/kernels/arm64/NEGEMMLowpAArch64Kernel.h" -#include "arm_compute/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" @@ -39,20 +35,11 @@ #include "arm_compute/runtime/TensorAllocator.h" #include "support/ToolchainSupport.h" -namespace arm_compute -{ -#include "arm_compute/core/NEON/kernels/assembly/gemm_interleaved.hpp" -#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_s16_12x8.hpp" -#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_s8_12x8.hpp" -#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_s8_4x4.hpp" -#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u16_12x8.hpp" -#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_4x4.hpp" -} // namespace arm_compute - using namespace arm_compute; NEGEMMLowpAssemblyMatrixMultiplyCore::NEGEMMLowpAssemblyMatrixMultiplyCore(std::shared_ptr memory_manager) - : _memory_group(std::move(memory_manager)), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _tmp_a(), _tmp_b(), _workspace() + : _memory_group(std::move(memory_manager)), _asm_glue_unsigned(), _asm_glue_signed(), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _tmp_a(), _tmp_b(), + _workspace() { } @@ -65,89 +52,28 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::configure(const ITensor *a, const ITe ARM_COMPUTE_ERROR_ON_MSG((a)->info()->dimension(1) != (output)->info()->dimension(1), "The output matrix must have the same number of rows as the matrix A"); ARM_COMPUTE_ERROR_ON_MSG((b)->info()->dimension(0) != (output)->info()->dimension(0), "The output matrix must have the same number of columns as the matrix B"); + bool run_optimised = false; #ifdef __aarch64__ - const int M = output->info()->tensor_shape().y(); - const int N = output->info()->tensor_shape().x(); - const int K = a->info()->tensor_shape().x(); - constexpr size_t workspace_alignment = 4096; - const struct CPUInfo ci = NEScheduler::get().cpu_info(); -#endif /* __aarch64__ */ - -#ifdef ARM_COMPUTE_AARCH64_V8_2 - if(ci.CPU == CPUTarget::A75_DOT || ci.CPU == CPUTarget::A55_DOT) - { - // Configure matrix multiply kernel - GemmInterleaved gemm(&ci, M, N, K, false, false); - _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + workspace_alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8)); - _memory_group.manage(&_workspace); - - // Configure matrix multiplication kernel - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(a, b, output, &_workspace, 1.f, 1.f); - _mm_kernel = std::move(k); - _workspace.allocator()->allocate(); - } - else -#elif defined(ARM_COMPUTE_AARCH64_V8A) - if(ci.CPU == CPUTarget::A53) + switch(a->info()->data_type()) { - switch(a->info()->data_type()) + case DataType::S8: { - case DataType::S8: - { - // Configure matrix multiply kernel - GemmInterleaved gemm(&ci, M, N, K, false, false); - _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + workspace_alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8)); - } + run_optimised = setup_assembly_kernel(a, b, nullptr, output, 1.f, 1.f, _workspace, _memory_group, _asm_glue_signed); break; - case DataType::U8: - { - // Configure matrix multiply kernel - GemmInterleaved gemm(&ci, M, N, K, false, false); - _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + workspace_alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8)); - } - break; - default: - ARM_COMPUTE_ERROR("Datatype not supported"); } - - _memory_group.manage(&_workspace); - // Configure matrix multiplication kernel - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(a, b, output, &_workspace, 1.f, 1.f); - _mm_kernel = std::move(k); - _workspace.allocator()->allocate(); - } - else if(1) // Generic v8a kernel - { - switch(a->info()->data_type()) + case DataType::U8: { - case DataType::S8: - { - // Configure matrix multiply kernel - GemmInterleaved gemm(&ci, M, N, K, false, false); - _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + workspace_alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8)); - } + run_optimised = setup_assembly_kernel(a, b, nullptr, output, 1.f, 1.f, _workspace, _memory_group, _asm_glue_unsigned); break; - case DataType::U8: - { - // Configure matrix multiply kernel - GemmInterleaved gemm(&ci, M, N, K, false, false); - _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + workspace_alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8)); - } + } + default: + { + ARM_COMPUTE_ERROR("Datatype not supported"); break; - default: - ARM_COMPUTE_ERROR("Datatype not supported"); } - _memory_group.manage(&_workspace); - // Configure matrix multiplication kernel - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(a, b, output, &_workspace, 1.f, 1.f); - _mm_kernel = std::move(k); - _workspace.allocator()->allocate(); } - else -#endif /* ARM_COMPUTE_AARCH64_V8_2 */ +#endif /* __aarch64__ */ + if(!run_optimised) { // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ] TensorShape shape_tmp_a = a->info()->tensor_shape(); @@ -206,7 +132,18 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::run() NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY); } - NEScheduler::get().schedule(_mm_kernel.get(), Window::DimY); + if(_asm_glue_unsigned._optimised_kernel != nullptr) + { + _asm_glue_unsigned.run(); + } + else if(_asm_glue_signed._optimised_kernel != nullptr) + { + _asm_glue_signed.run(); + } + else + { + NEScheduler::get().schedule(_mm_kernel.get(), Window::DimY); + } _memory_group.release(); } diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp index ad47593f20..dc4ed5cefb 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp @@ -26,11 +26,9 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensor.h" -#include "arm_compute/core/NEON/kernels/NEGEMMAssemblyBaseKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h" -#include "arm_compute/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" @@ -39,18 +37,13 @@ #include "arm_compute/runtime/TensorAllocator.h" #include "support/ToolchainSupport.h" -namespace arm_compute -{ -#include "arm_compute/core/NEON/kernels/assembly/gemm_interleaved.hpp" -#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_12x8.hpp" -} // namespace arm_compute - using namespace arm_compute; using namespace arm_compute::misc::shape_calculator; NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr memory_manager) - : _memory_group(std::move(memory_manager)), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(), - _offset_contribution_kernel(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _workspace(), _a_offset(0), _b_offset(0), _run_vector_matrix_multiplication(false), _dot_product_path(false) + : _memory_group(std::move(memory_manager)), _asm_glue_unsigned(), _asm_glue_signed(), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _mtx_a_reduction_kernel(), + _mtx_b_reduction_kernel(), _offset_contribution_kernel(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _workspace(), _a_offset(0), _b_offset(0), _run_vector_matrix_multiplication(false), + _dot_product_path(false) { } @@ -64,33 +57,27 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, _b_offset = b->info()->quantization_info().offset; _run_vector_matrix_multiplication = a->info()->dimension(1) < 2; -#ifdef ARM_COMPUTE_AARCH64_V8_2 - // Check for DOT product instruction - const struct CPUInfo ci = NEScheduler::get().cpu_info(); - const int cpu_has_dotprod = static_cast(ci.CPU) & static_cast(CPUTarget::DOT); - - if(cpu_has_dotprod != 0) +#ifdef __aarch64__ + switch(a->info()->data_type()) { - _dot_product_path = true; - - // Configure matrix multiply kernel - struct CPUInfo ci = NEScheduler::get().cpu_info(); - const int M = output->info()->tensor_shape().y(); - const int N = output->info()->tensor_shape().x(); - const int K = a->info()->tensor_shape().x(); - - const size_t workbench_size = GemmInterleaved(&ci, M, N, K, false, false).get_working_size(); - constexpr size_t alignment = 4096; - _workspace.allocator()->init(TensorInfo(TensorShape{ (workbench_size + alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8)); - _memory_group.manage(&_workspace); - - // Configure matrix multiplication kernel - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(a, b, output, &_workspace, 1.f, 1.f, false, false); - _mm_kernel = std::move(k); + case DataType::S8: + { + _dot_product_path = setup_assembly_kernel(a, b, nullptr, output, 1.f, 1.f, _workspace, _memory_group, _asm_glue_signed); + break; + } + case DataType::U8: + { + _dot_product_path = setup_assembly_kernel(a, b, nullptr, output, 1.f, 1.f, _workspace, _memory_group, _asm_glue_unsigned); + break; + } + default: + { + ARM_COMPUTE_ERROR("Datatype not supported"); + break; + } } - else -#endif /* ARM_COMPUTE_AARCH64_V8_2 */ +#endif /* __aarch64__ */ + if(!_dot_product_path) { if(_run_vector_matrix_multiplication) { @@ -203,42 +190,28 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso int32_t b_offset = b->quantization_info().offset; bool run_vector_matrix_multiplication = a->dimension(1) < 2; -#ifdef ARM_COMPUTE_AARCH64_V8_2 - // Check for DOT product instruction - const struct CPUInfo ci = NEScheduler::get().cpu_info(); - const int cpu_has_dotprod = static_cast(ci.CPU) & static_cast(CPUTarget::DOT); - - if(cpu_has_dotprod != 0) + if(!run_vector_matrix_multiplication) { - // Validate matrix multiply kernel - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpAArch64V8P4Kernel::validate(a, b, output)); + // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ] + TensorShape shape_tmp_a = a->tensor_shape(); + shape_tmp_a.set(0, a->dimension(0) * 4); + shape_tmp_a.set(1, std::ceil(a->dimension(1) / 4.f)); + + // The transpose1xW output matrix will have the following shape: [ b_height * 16, ceil(b_width / 16.0f) ] + TensorShape shape_tmp_b = b->tensor_shape(); + shape_tmp_b.set(0, b->dimension(1) * 16); + shape_tmp_b.set(1, std::ceil(b->dimension(0) / 16.f)); + + TensorInfo info_a(shape_tmp_a, 1, a->data_type()); + TensorInfo info_b(shape_tmp_b, 1, b->data_type()); + + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(a, &info_a)); + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(b, &info_b)); + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyKernel::validate(&info_a, &info_b, output)); } else -#endif /* ARM_COMPUTE_AARCH64_V8_2 */ { - if(!run_vector_matrix_multiplication) - { - // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ] - TensorShape shape_tmp_a = a->tensor_shape(); - shape_tmp_a.set(0, a->dimension(0) * 4); - shape_tmp_a.set(1, std::ceil(a->dimension(1) / 4.f)); - - // The transpose1xW output matrix will have the following shape: [ b_height * 16, ceil(b_width / 16.0f) ] - TensorShape shape_tmp_b = b->tensor_shape(); - shape_tmp_b.set(0, b->dimension(1) * 16); - shape_tmp_b.set(1, std::ceil(b->dimension(0) / 16.f)); - - TensorInfo info_a(shape_tmp_a, 1, a->data_type()); - TensorInfo info_b(shape_tmp_b, 1, b->data_type()); - - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(a, &info_a)); - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(b, &info_b)); - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyKernel::validate(&info_a, &info_b, output)); - } - else - { - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyKernel::validate(a, b, output)); - } + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyKernel::validate(a, b, output)); } TensorInfo info_vector_sum_col, info_vector_sum_row; @@ -288,7 +261,18 @@ void NEGEMMLowpMatrixMultiplyCore::run() } } - NEScheduler::get().schedule(_mm_kernel.get(), Window::DimY); + if(_asm_glue_unsigned._optimised_kernel != nullptr) + { + _asm_glue_unsigned.run(); + } + else if(_asm_glue_signed._optimised_kernel != nullptr) + { + _asm_glue_signed.run(); + } + else + { + NEScheduler::get().schedule(_mm_kernel.get(), Window::DimY); + } // Run matrix A reduction kernel only if _b_offset is not equal to 0 if(_b_offset != 0) -- cgit v1.2.1