aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2018-02-23 13:43:50 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:16 +0000
commiteb82fd2aa786715c3b6a941dc6d6deac4ce8e2a0 (patch)
tree42cca378eed97c07348f28e1ec708d9c7ed531ce /src/runtime/NEON
parent8df6c452820719d201ee79596cde8445c2071db5 (diff)
downloadComputeLibrary-eb82fd2aa786715c3b6a941dc6d6deac4ce8e2a0.tar.gz
COMPMID-881: RSH new arm_gemm interface.
Change-Id: I1e2a1a77097d8017c274af3f97eba6964f80f5fa Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122592 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/runtime/NEON')
-rw-r--r--src/runtime/NEON/functions/NEGEMM.cpp118
-rw-r--r--src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp79
-rw-r--r--src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp117
-rw-r--r--src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp120
4 files changed, 97 insertions, 337 deletions
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 05907bab07..c8cba8a174 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -26,37 +26,20 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/NEON/kernels/arm32/NEGEMMAArch32Kernel.h"
-#include "arm_compute/core/NEON/kernels/arm64/NEGEMMAArch64Kernel.h"
-#include "arm_compute/core/NEON/kernels/arm64/NEGEMVAArch64Kernel.h"
-#include "arm_compute/core/NEON/kernels/arm64/NEHGEMMAArch64FP16Kernel.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/NEON/AssemblyHelper.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/TensorAllocator.h"
#include "support/ToolchainSupport.h"
-namespace arm_compute
-{
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wswitch-default"
-#pragma GCC diagnostic ignored "-Weffc++"
-#include "arm_compute/core/NEON/kernels/assembly/gemm_interleaved.hpp"
-#include "arm_compute/core/NEON/kernels/assembly/gemv_transposed.hpp"
-#include "arm_compute/core/NEON/kernels/assembly/kernels/a32_sgemm_8x6.hpp"
-#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_hgemm_24x8.hpp"
-#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_sgemm_12x8.hpp"
-#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_sgemv_trans.hpp"
-#pragma GCC diagnostic pop
-} // namespace arm_compute
-
#include <cmath>
namespace arm_compute
{
NEGEMM::NEGEMM(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _mm_optimised_kernel(nullptr), _ma_kernel(), _tmp_a(), _tmp_b(), _workspace(),
+ : _memory_group(std::move(memory_manager)), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(), _ma_kernel(), _tmp_a(), _tmp_b(), _workspace(),
_run_vector_matrix_multiplication(false), _run_addition(false), _is_first_run(true), _reshape_b_only_on_first_run(false)
{
}
@@ -82,42 +65,13 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
// Check if we need to reshape the matrix B only on the first run
_reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
_run_vector_matrix_multiplication = a->info()->dimension(1) < 2;
+ const bool run_optimised = setup_assembly_kernel(a, b, c, d, alpha, beta, _workspace, _memory_group, _asm_glue);
// Check if the first input tensor is a vector.
// If so, all the kernels for reshaping the tensors can be skipped
if(_run_vector_matrix_multiplication)
{
-#if defined(__aarch64__)
- if(NEScheduler::get().cpu_info().CPU >= CPUTarget::ARMV8 && a->info()->data_type() == DataType::F32 && (c == nullptr || beta == 0.f))
- {
- _mm_optimised_kernel = support::cpp14::make_unique<NEGEMVAArch64Kernel>();
- }
-
- if(_mm_optimised_kernel != nullptr)
- {
- struct CPUInfo ci = NEScheduler::get().cpu_info();
-
- const int N = d->info()->tensor_shape().x();
- const int K = a->info()->tensor_shape().x();
-
- size_t workbench_size = 0;
-
- if(a->info()->data_type() == DataType::F32)
- {
- workbench_size = GemvTransposed<sgemv_trans, sgemv_trans::operand_type, sgemv_trans::result_type>(&ci, N, K).get_working_size();
- }
-
- constexpr size_t alignment = 4096;
- ARM_COMPUTE_ERROR_ON_MSG(workbench_size == 0, "size cannot be 0");
- _workspace.allocator()->init(TensorInfo(TensorShape{ (workbench_size + alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::S8));
- _memory_group.manage(&_workspace);
-
- // Configure matrix multiplication kernel
- _mm_optimised_kernel->configure(a, b, d, &_workspace, alpha, 0.f, false /* is_transposed_0 */, false /* is_transposed_1 */);
- _workspace.allocator()->allocate();
- }
- else
-#endif /* defined(__aarch64__) */
+ if(!run_optimised)
{
// Configure the matrix multiply kernel
_mm_kernel.configure(a, b, d, alpha, false);
@@ -132,65 +86,7 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
}
else
{
-#if defined(__arm__)
- if(NEScheduler::get().cpu_info().CPU == CPUTarget::ARMV7 && a->info()->data_type() == DataType::F32 && (c == nullptr || beta == 0.f))
- {
- _mm_optimised_kernel = support::cpp14::make_unique<NEGEMMAArch32Kernel>();
- }
-#elif defined(__aarch64__)
- if(NEScheduler::get().cpu_info().CPU >= CPUTarget::ARMV8 && a->info()->data_type() == DataType::F32 && (c == nullptr || beta == 0.f))
- {
- _mm_optimised_kernel = support::cpp14::make_unique<NEGEMMAArch64Kernel>();
- }
- else if(a->info()->data_type() == DataType::F16 && (c == nullptr || beta == 0.f))
- {
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- _mm_optimised_kernel = support::cpp14::make_unique<NEHGEMMAArch64FP16Kernel>();
-#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- ARM_COMPUTE_ERROR("Recompile the library with arch=arm64-v8.2-a to enable support for FP16.");
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- }
-#endif /* defined(__arm__) || defined(__aarch64__) */
-
-#if defined(__arm__) || defined(__aarch64__)
- if(_mm_optimised_kernel != nullptr)
- {
- struct CPUInfo ci = NEScheduler::get().cpu_info();
-
- const int M = d->info()->tensor_shape().y();
- const int N = d->info()->tensor_shape().x();
- const int K = a->info()->tensor_shape().x();
-
- size_t workbench_size = 0;
-
-#if defined(__arm__)
- workbench_size = GemmInterleaved<sgemm_8x6, sgemm_8x6::operand_type, sgemm_8x6::result_type>(&ci, M, N, K, false, false).get_working_size();
-#elif defined(__aarch64__)
- if(a->info()->data_type() == DataType::F32)
- {
- workbench_size = GemmInterleaved<sgemm_12x8, sgemm_12x8::operand_type, sgemm_12x8::result_type>(&ci, M, N, K, false, false).get_working_size();
- }
- else if(a->info()->data_type() == DataType::F16)
- {
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- workbench_size = GemmInterleaved<hgemm_24x8, hgemm_24x8::operand_type, hgemm_24x8::result_type>(&ci, M, N, K, false, false).get_working_size();
-#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- ARM_COMPUTE_ERROR("Recompile the library with arch=arm64-v8.2-a to enable support for FP16.");
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- }
-#endif /* defined(__arm__) || defined(__aarch64__) */
-
- constexpr size_t alignment = 4096;
- ARM_COMPUTE_ERROR_ON_MSG(workbench_size == 0, "size cannot be 0");
- _workspace.allocator()->init(TensorInfo(TensorShape{ (workbench_size + alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::S8));
- _memory_group.manage(&_workspace);
-
- // Configure matrix multiplication kernel
- _mm_optimised_kernel->configure(a, b, d, &_workspace, alpha, 0.f, false /* is_transposed_0 */, false /* is_transposed_1 */);
- _workspace.allocator()->allocate();
- }
- else
-#endif /* defined(__arm__) || defined(__aarch64__) */
+ if(!run_optimised)
{
TensorShape shape_tmp_a = a->info()->tensor_shape();
TensorShape shape_tmp_b = b->info()->tensor_shape();
@@ -243,9 +139,9 @@ void NEGEMM::run()
{
_memory_group.acquire();
- if(_mm_optimised_kernel != nullptr)
+ if(_asm_glue._optimised_kernel != nullptr)
{
- NEScheduler::get().schedule(_mm_optimised_kernel.get(), Window::DimY);
+ _asm_glue.run();
_memory_group.release();
}
else
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index a85078cf71..3b8b4243e5 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -23,9 +23,6 @@
*/
#include "arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h"
-#include "arm_compute/core/NEON/kernels/arm32/NEGEMMAArch32Kernel.h"
-#include "arm_compute/core/NEON/kernels/arm64/NEGEMMAArch64Kernel.h"
-#include "arm_compute/core/NEON/kernels/arm64/NEGEMMAArch64NativeKernel.h"
#include "arm_compute/core/PixelValue.h"
#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/Utils.h"
@@ -34,13 +31,6 @@
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "support/ToolchainSupport.h"
-namespace arm_compute
-{
-#include "arm_compute/core/NEON/kernels/assembly/gemm_interleaved.hpp"
-#include "arm_compute/core/NEON/kernels/assembly/kernels/a32_sgemm_8x6.hpp"
-#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_sgemm_12x8.hpp"
-} // namespace arm_compute
-
#include <cmath>
#include <tuple>
@@ -226,8 +216,8 @@ Status validate_and_initialize_values(const ITensorInfo *input, const ITensorInf
} // namespace
NEGEMMConvolutionLayer::NEGEMMConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager)
- : _memory_group(memory_manager), _input_im2col_kernel(), _input_interleave_kernel(), _reshape_weights(), _mm_kernel(), _mm_optimised_kernel(nullptr), _mm_gemmlowp(memory_manager),
- _gemmlowp_output_stage(), _output_col2im_kernel(), _input_im2col_reshaped(), _input_interleaved_reshaped(), _weights_reshaped(), _gemm_output(), _tmp_output(), _workspace(), _append_bias(false),
+ : _asm_glue(), _memory_group(memory_manager), _input_im2col_kernel(), _input_interleave_kernel(), _reshape_weights(), _mm_kernel(), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
+ _output_col2im_kernel(), _input_im2col_reshaped(), _input_interleaved_reshaped(), _weights_reshaped(), _gemm_output(), _tmp_output(), _workspace(), _append_bias(false),
_is_fully_connected_convolution(false), _are_weights_reshaped(false), _is_quantized(false), _is_interleaved(false)
{
}
@@ -256,25 +246,6 @@ void NEGEMMConvolutionLayer::configure_mm(const ITensor *input, const ITensor *w
}
}
-void NEGEMMConvolutionLayer::configure_asm_mm(const struct CPUInfo &ci, int M, int N, int K)
-{
- ARM_COMPUTE_UNUSED(ci);
- ARM_COMPUTE_UNUSED(M);
- ARM_COMPUTE_UNUSED(N);
- ARM_COMPUTE_UNUSED(K);
-#if defined(__arm__) || defined(__aarch64__)
-#if defined(__arm__)
- GemmInterleaved<sgemm_8x6, float, float> gemm(&ci, M, N, K, false, false);
-#elif defined(__aarch64__)
- GemmInterleaved<sgemm_12x8, float, float> gemm(&ci, M, N, K, false, false);
-#endif /* defined(__arm__) || defined(__aarch64__) */
-
- constexpr size_t alignment = 4096;
- _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8));
- _memory_group.manage(&_workspace);
-#endif /* defined(__arm__) || defined(__aarch64__) */
-}
-
void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info)
{
// Perform validate step
@@ -298,20 +269,11 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
const unsigned int fixed_point_position = input->info()->fixed_point_position();
const ITensor *biases_to_use = (_append_bias) ? biases : nullptr;
-#if defined(__arm__)
- if(NEScheduler::get().cpu_info().CPU == CPUTarget::ARMV7 && dt == DataType::F32)
- {
- _mm_optimised_kernel = support::cpp14::make_unique<NEGEMMAArch32Kernel>();
- }
-#elif defined(__aarch64__)
- if(NEScheduler::get().cpu_info().CPU >= CPUTarget::ARMV8 && dt == DataType::F32)
- {
- _mm_optimised_kernel = support::cpp14::make_unique<NEGEMMAArch64Kernel>();
- }
-#endif /* defined(__arm__) || defined(__aarch64__) */
+ bool run_optimised =
+ (NEScheduler::get().cpu_info().CPU == CPUTarget::ARMV7 && dt == DataType::F32) || (NEScheduler::get().cpu_info().CPU >= CPUTarget::ARMV8 && dt == DataType::F32);
// Reshape weights if needed
- if(_mm_optimised_kernel != nullptr)
+ if(run_optimised)
{
if(_are_weights_reshaped)
{
@@ -378,7 +340,7 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
_memory_group.manage(&_input_im2col_reshaped);
// Create tensor (interleave) to prepare input tensor for GEMM
- if(!_is_fully_connected_convolution && _mm_optimised_kernel == nullptr)
+ if(!_is_fully_connected_convolution && !run_optimised)
{
TensorShape shape_interleaved(shape_im2col);
shape_interleaved.set(0, shape_interleaved.x() * 4);
@@ -403,29 +365,10 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
_input_im2col_kernel.configure(input, &_input_im2col_reshaped, Size2D(kernel_width, kernel_height), conv_info, _append_bias);
// Configure matrix multiply
- if(_mm_optimised_kernel != nullptr)
+ if(run_optimised)
{
- struct CPUInfo ci = NEScheduler::get().cpu_info();
-
- const int M = _gemm_output.info()->tensor_shape().y();
- const int N = _gemm_output.info()->tensor_shape().x();
- const int K = _input_im2col_reshaped.info()->tensor_shape().x();
-
-#if defined(__aarch64__)
- if((N <= 128) && (K <= 128))
- {
- _mm_optimised_kernel = support::cpp14::make_unique<NEGEMMAArch64NativeKernel>();
- }
- else
-#endif /* defined(__aarch64__) */
- {
- configure_asm_mm(ci, M, N, K);
- }
-
- // Configure matrix multiplication kernel
- _mm_optimised_kernel->configure(&_input_im2col_reshaped, weights, &_gemm_output, &_workspace);
-
- _workspace.allocator()->allocate();
+ run_optimised = setup_assembly_kernel(&_input_im2col_reshaped, weights, nullptr, &_gemm_output, 1.f, 0.f, _workspace, _memory_group, _asm_glue);
+ ARM_COMPUTE_ERROR_ON_MSG(run_optimised == false, "setup_assembly_kernel failed.");
}
else
{
@@ -615,9 +558,9 @@ void NEGEMMConvolutionLayer::run()
NEScheduler::get().schedule(&_input_im2col_kernel, Window::DimY);
// Runs matrix multiply on reshaped matrices
- if(_mm_optimised_kernel != nullptr)
+ if(_asm_glue._optimised_kernel != nullptr)
{
- NEScheduler::get().schedule(_mm_optimised_kernel.get(), Window::DimY);
+ _asm_glue.run();
}
else
{
diff --git a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
index 9b36e81afd..e5e97910d8 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
@@ -1,4 +1,4 @@
-/* Copyright (c) 2017 ARM Limited.
+/* Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,13 +25,9 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/NEON/kernels/NEGEMMAssemblyBaseKernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
-#include "arm_compute/core/NEON/kernels/arm64/NEGEMMLowpAArch64A53Kernel.h"
-#include "arm_compute/core/NEON/kernels/arm64/NEGEMMLowpAArch64Kernel.h"
-#include "arm_compute/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
@@ -39,20 +35,11 @@
#include "arm_compute/runtime/TensorAllocator.h"
#include "support/ToolchainSupport.h"
-namespace arm_compute
-{
-#include "arm_compute/core/NEON/kernels/assembly/gemm_interleaved.hpp"
-#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_s16_12x8.hpp"
-#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_s8_12x8.hpp"
-#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_s8_4x4.hpp"
-#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u16_12x8.hpp"
-#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_4x4.hpp"
-} // namespace arm_compute
-
using namespace arm_compute;
NEGEMMLowpAssemblyMatrixMultiplyCore::NEGEMMLowpAssemblyMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _tmp_a(), _tmp_b(), _workspace()
+ : _memory_group(std::move(memory_manager)), _asm_glue_unsigned(), _asm_glue_signed(), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _tmp_a(), _tmp_b(),
+ _workspace()
{
}
@@ -65,89 +52,28 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::configure(const ITensor *a, const ITe
ARM_COMPUTE_ERROR_ON_MSG((a)->info()->dimension(1) != (output)->info()->dimension(1), "The output matrix must have the same number of rows as the matrix A");
ARM_COMPUTE_ERROR_ON_MSG((b)->info()->dimension(0) != (output)->info()->dimension(0), "The output matrix must have the same number of columns as the matrix B");
+ bool run_optimised = false;
#ifdef __aarch64__
- const int M = output->info()->tensor_shape().y();
- const int N = output->info()->tensor_shape().x();
- const int K = a->info()->tensor_shape().x();
- constexpr size_t workspace_alignment = 4096;
- const struct CPUInfo ci = NEScheduler::get().cpu_info();
-#endif /* __aarch64__ */
-
-#ifdef ARM_COMPUTE_AARCH64_V8_2
- if(ci.CPU == CPUTarget::A75_DOT || ci.CPU == CPUTarget::A55_DOT)
- {
- // Configure matrix multiply kernel
- GemmInterleaved<gemm_s8_12x8, int8_t, int32_t> gemm(&ci, M, N, K, false, false);
- _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + workspace_alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8));
- _memory_group.manage(&_workspace);
-
- // Configure matrix multiplication kernel
- auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpAArch64V8P4Kernel>();
- k->configure(a, b, output, &_workspace, 1.f, 1.f);
- _mm_kernel = std::move(k);
- _workspace.allocator()->allocate();
- }
- else
-#elif defined(ARM_COMPUTE_AARCH64_V8A)
- if(ci.CPU == CPUTarget::A53)
+ switch(a->info()->data_type())
{
- switch(a->info()->data_type())
+ case DataType::S8:
{
- case DataType::S8:
- {
- // Configure matrix multiply kernel
- GemmInterleaved<gemm_s16_12x8, int8_t, int32_t> gemm(&ci, M, N, K, false, false);
- _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + workspace_alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8));
- }
+ run_optimised = setup_assembly_kernel(a, b, nullptr, output, 1.f, 1.f, _workspace, _memory_group, _asm_glue_signed);
break;
- case DataType::U8:
- {
- // Configure matrix multiply kernel
- GemmInterleaved<gemm_u16_12x8, uint8_t, uint32_t> gemm(&ci, M, N, K, false, false);
- _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + workspace_alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8));
- }
- break;
- default:
- ARM_COMPUTE_ERROR("Datatype not supported");
}
-
- _memory_group.manage(&_workspace);
- // Configure matrix multiplication kernel
- auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpAArch64A53Kernel>();
- k->configure(a, b, output, &_workspace, 1.f, 1.f);
- _mm_kernel = std::move(k);
- _workspace.allocator()->allocate();
- }
- else if(1) // Generic v8a kernel
- {
- switch(a->info()->data_type())
+ case DataType::U8:
{
- case DataType::S8:
- {
- // Configure matrix multiply kernel
- GemmInterleaved<gemm_s8_4x4, int8_t, int32_t> gemm(&ci, M, N, K, false, false);
- _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + workspace_alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8));
- }
+ run_optimised = setup_assembly_kernel(a, b, nullptr, output, 1.f, 1.f, _workspace, _memory_group, _asm_glue_unsigned);
break;
- case DataType::U8:
- {
- // Configure matrix multiply kernel
- GemmInterleaved<gemm_u8_4x4, uint8_t, uint32_t> gemm(&ci, M, N, K, false, false);
- _workspace.allocator()->init(TensorInfo(TensorShape{ (gemm.get_working_size() + workspace_alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8));
- }
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Datatype not supported");
break;
- default:
- ARM_COMPUTE_ERROR("Datatype not supported");
}
- _memory_group.manage(&_workspace);
- // Configure matrix multiplication kernel
- auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpAArch64Kernel>();
- k->configure(a, b, output, &_workspace, 1.f, 1.f);
- _mm_kernel = std::move(k);
- _workspace.allocator()->allocate();
}
- else
-#endif /* ARM_COMPUTE_AARCH64_V8_2 */
+#endif /* __aarch64__ */
+ if(!run_optimised)
{
// The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ]
TensorShape shape_tmp_a = a->info()->tensor_shape();
@@ -206,7 +132,18 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::run()
NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY);
}
- NEScheduler::get().schedule(_mm_kernel.get(), Window::DimY);
+ if(_asm_glue_unsigned._optimised_kernel != nullptr)
+ {
+ _asm_glue_unsigned.run();
+ }
+ else if(_asm_glue_signed._optimised_kernel != nullptr)
+ {
+ _asm_glue_signed.run();
+ }
+ else
+ {
+ NEScheduler::get().schedule(_mm_kernel.get(), Window::DimY);
+ }
_memory_group.release();
}
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index ad47593f20..dc4ed5cefb 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -26,11 +26,9 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/NEON/kernels/NEGEMMAssemblyBaseKernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
-#include "arm_compute/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
@@ -39,18 +37,13 @@
#include "arm_compute/runtime/TensorAllocator.h"
#include "support/ToolchainSupport.h"
-namespace arm_compute
-{
-#include "arm_compute/core/NEON/kernels/assembly/gemm_interleaved.hpp"
-#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_12x8.hpp"
-} // namespace arm_compute
-
using namespace arm_compute;
using namespace arm_compute::misc::shape_calculator;
NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(),
- _offset_contribution_kernel(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _workspace(), _a_offset(0), _b_offset(0), _run_vector_matrix_multiplication(false), _dot_product_path(false)
+ : _memory_group(std::move(memory_manager)), _asm_glue_unsigned(), _asm_glue_signed(), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _mtx_a_reduction_kernel(),
+ _mtx_b_reduction_kernel(), _offset_contribution_kernel(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _workspace(), _a_offset(0), _b_offset(0), _run_vector_matrix_multiplication(false),
+ _dot_product_path(false)
{
}
@@ -64,33 +57,27 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
_b_offset = b->info()->quantization_info().offset;
_run_vector_matrix_multiplication = a->info()->dimension(1) < 2;
-#ifdef ARM_COMPUTE_AARCH64_V8_2
- // Check for DOT product instruction
- const struct CPUInfo ci = NEScheduler::get().cpu_info();
- const int cpu_has_dotprod = static_cast<int>(ci.CPU) & static_cast<int>(CPUTarget::DOT);
-
- if(cpu_has_dotprod != 0)
+#ifdef __aarch64__
+ switch(a->info()->data_type())
{
- _dot_product_path = true;
-
- // Configure matrix multiply kernel
- struct CPUInfo ci = NEScheduler::get().cpu_info();
- const int M = output->info()->tensor_shape().y();
- const int N = output->info()->tensor_shape().x();
- const int K = a->info()->tensor_shape().x();
-
- const size_t workbench_size = GemmInterleaved<gemm_u8_12x8, gemm_u8_12x8::operand_type, gemm_u8_12x8::result_type>(&ci, M, N, K, false, false).get_working_size();
- constexpr size_t alignment = 4096;
- _workspace.allocator()->init(TensorInfo(TensorShape{ (workbench_size + alignment - 1) * NEScheduler::get().num_threads() }, 1, DataType::U8));
- _memory_group.manage(&_workspace);
-
- // Configure matrix multiplication kernel
- auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpAArch64V8P4Kernel>();
- k->configure(a, b, output, &_workspace, 1.f, 1.f, false, false);
- _mm_kernel = std::move(k);
+ case DataType::S8:
+ {
+ _dot_product_path = setup_assembly_kernel(a, b, nullptr, output, 1.f, 1.f, _workspace, _memory_group, _asm_glue_signed);
+ break;
+ }
+ case DataType::U8:
+ {
+ _dot_product_path = setup_assembly_kernel(a, b, nullptr, output, 1.f, 1.f, _workspace, _memory_group, _asm_glue_unsigned);
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Datatype not supported");
+ break;
+ }
}
- else
-#endif /* ARM_COMPUTE_AARCH64_V8_2 */
+#endif /* __aarch64__ */
+ if(!_dot_product_path)
{
if(_run_vector_matrix_multiplication)
{
@@ -203,42 +190,28 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
int32_t b_offset = b->quantization_info().offset;
bool run_vector_matrix_multiplication = a->dimension(1) < 2;
-#ifdef ARM_COMPUTE_AARCH64_V8_2
- // Check for DOT product instruction
- const struct CPUInfo ci = NEScheduler::get().cpu_info();
- const int cpu_has_dotprod = static_cast<int>(ci.CPU) & static_cast<int>(CPUTarget::DOT);
-
- if(cpu_has_dotprod != 0)
+ if(!run_vector_matrix_multiplication)
{
- // Validate matrix multiply kernel
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpAArch64V8P4Kernel::validate(a, b, output));
+ // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ]
+ TensorShape shape_tmp_a = a->tensor_shape();
+ shape_tmp_a.set(0, a->dimension(0) * 4);
+ shape_tmp_a.set(1, std::ceil(a->dimension(1) / 4.f));
+
+ // The transpose1xW output matrix will have the following shape: [ b_height * 16, ceil(b_width / 16.0f) ]
+ TensorShape shape_tmp_b = b->tensor_shape();
+ shape_tmp_b.set(0, b->dimension(1) * 16);
+ shape_tmp_b.set(1, std::ceil(b->dimension(0) / 16.f));
+
+ TensorInfo info_a(shape_tmp_a, 1, a->data_type());
+ TensorInfo info_b(shape_tmp_b, 1, b->data_type());
+
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(a, &info_a));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(b, &info_b));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyKernel::validate(&info_a, &info_b, output));
}
else
-#endif /* ARM_COMPUTE_AARCH64_V8_2 */
{
- if(!run_vector_matrix_multiplication)
- {
- // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ]
- TensorShape shape_tmp_a = a->tensor_shape();
- shape_tmp_a.set(0, a->dimension(0) * 4);
- shape_tmp_a.set(1, std::ceil(a->dimension(1) / 4.f));
-
- // The transpose1xW output matrix will have the following shape: [ b_height * 16, ceil(b_width / 16.0f) ]
- TensorShape shape_tmp_b = b->tensor_shape();
- shape_tmp_b.set(0, b->dimension(1) * 16);
- shape_tmp_b.set(1, std::ceil(b->dimension(0) / 16.f));
-
- TensorInfo info_a(shape_tmp_a, 1, a->data_type());
- TensorInfo info_b(shape_tmp_b, 1, b->data_type());
-
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(a, &info_a));
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(b, &info_b));
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyKernel::validate(&info_a, &info_b, output));
- }
- else
- {
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyKernel::validate(a, b, output));
- }
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyKernel::validate(a, b, output));
}
TensorInfo info_vector_sum_col, info_vector_sum_row;
@@ -288,7 +261,18 @@ void NEGEMMLowpMatrixMultiplyCore::run()
}
}
- NEScheduler::get().schedule(_mm_kernel.get(), Window::DimY);
+ if(_asm_glue_unsigned._optimised_kernel != nullptr)
+ {
+ _asm_glue_unsigned.run();
+ }
+ else if(_asm_glue_signed._optimised_kernel != nullptr)
+ {
+ _asm_glue_signed.run();
+ }
+ else
+ {
+ NEScheduler::get().schedule(_mm_kernel.get(), Window::DimY);
+ }
// Run matrix A reduction kernel only if _b_offset is not equal to 0
if(_b_offset != 0)