From fef6dae9c2cfe1003ab2abe3a41255e849b1b5eb Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Fri, 15 Dec 2017 10:36:21 +0000 Subject: COMPMID-750: Enabled support for U8 and S8 datatypes in NEGEMMLowpAArch64V8P4Kernel Change-Id: If32cbdc65f2e1441595cae5b4824a9b4357c8bf6 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/113467 Tested-by: Jenkins Reviewed-by: Anthony Barbier Reviewed-by: Georgios Pinitas --- .../kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp | 73 +++++++++++++++------- .../NEGEMMLowpAssemblyMatrixMultiplyCore.cpp | 6 +- 2 files changed, 53 insertions(+), 26 deletions(-) diff --git a/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp b/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp index 099934a49d..7827bc1ccf 100644 --- a/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp +++ b/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp @@ -38,6 +38,7 @@ namespace arm_compute { #include "arm_compute/core/NEON/kernels/assembly/gemm_interleaved.hpp" +#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_s8_12x8.hpp" #include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_12x8.hpp" } // namespace arm_compute @@ -54,7 +55,7 @@ using namespace arm_compute; Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output) { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::U8, DataType::S8); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1); @@ -80,6 +81,38 @@ std::pair validate_and_configure_window(ITensorInfo *input0, ITe Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; return std::make_pair(err, win); } + +template +void *align_workspace(GemmInterleaved &gemm, const ThreadInfo &info, ITensor *ws) +{ + constexpr size_t alignment = 4096; + const size_t offset = (gemm.get_working_size() + alignment - 1) * info.thread_id; + void *workspace = ws->buffer() + offset; + size_t workspace_size = ws->info()->total_size(); + + if(support::cpp11::align(alignment, gemm.get_working_size(), workspace, workspace_size) == nullptr) + { + ARM_COMPUTE_ERROR("Not enough space to align buffer!"); + } + return workspace; +} + +template +void execute_gemm(const Window &win, Iterator &in0, Iterator &in1, Iterator &out, + const ThreadInfo &info, ITensor *ws, int M, int N, int K, bool is_transposed_0, bool is_transposed_1, + int lda, int ldb, int ldc, float alpha, float beta) +{ + GemmInterleaved gemm(&info.cpu_info, M, N, K, is_transposed_0, is_transposed_1); + void *workspace = align_workspace(gemm, info, ws); + execute_window_loop(win, [&](const Coordinates & id) + { + gemm.execute(reinterpret_cast(in0.ptr()), lda, + reinterpret_cast(in1.ptr()), ldb, + reinterpret_cast(out.ptr()), ldc, + alpha, beta, workspace); + }, + in0, out); +} } // namespace namespace arm_compute @@ -123,8 +156,6 @@ void NEGEMMLowpAArch64V8P4Kernel::run(const Window &window, const ThreadInfo &in const int ldb = _input1->info()->strides_in_bytes().y(); const int ldc = _output->info()->strides_in_bytes().y() / sizeof(uint32_t); - const auto in1_ptr = reinterpret_cast(_input1->buffer()); - const int M = std::min(_output->info()->tensor_shape().y(), static_cast(window.y().end())) - window.y().start(); const int N = _output->info()->tensor_shape().x(); const int K = _input0->info()->tensor_shape().x(); @@ -135,28 +166,28 @@ void NEGEMMLowpAArch64V8P4Kernel::run(const Window &window, const ThreadInfo &in win.set(1, Window::Dimension(0, 1, 1)); Iterator in0(_input0, window); + Iterator in1(_input1, window); Iterator out(_output, window); - GemmInterleaved gemm(&info.cpu_info, M, N, K, _is_transposed_0, _is_transposed_1); - - constexpr size_t alignment = 4096; - const size_t offset = (gemm.get_working_size() + alignment - 1) * info.thread_id; - void *workspace = _workspace->buffer() + offset; - size_t workspace_size = _workspace->info()->total_size(); - - if(support::cpp11::align(alignment, gemm.get_working_size(), workspace, workspace_size) == nullptr) + switch(_input0->info()->data_type()) { - ARM_COMPUTE_ERROR("Not enough space to align buffer!"); + case DataType::QASYMM8: + case DataType::U8: + { + execute_gemm(win, in0, in1, out, info, _workspace, M, N, K, _is_transposed_0, _is_transposed_1, lda, ldb, ldc, _alpha, _beta); + break; + } + case DataType::S8: + { + execute_gemm(win, in0, in1, out, info, _workspace, M, N, K, _is_transposed_0, _is_transposed_1, lda, ldb, ldc, _alpha, _beta); + break; + } + default: + { + ARM_COMPUTE_ERROR("Not supported."); + break; + } } - - execute_window_loop(win, [&](const Coordinates & id) - { - gemm.execute(reinterpret_cast(in0.ptr()), lda, - reinterpret_cast(in1_ptr), ldb, - reinterpret_cast(out.ptr()), ldc, - _alpha, _beta, workspace); - }, - in0, out); } } // namespace arm_compute #endif /* ARM_COMPUTE_AARCH64_V8_2 */ diff --git a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp index 6e03ffa1bc..9b36e81afd 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp @@ -74,7 +74,7 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::configure(const ITensor *a, const ITe #endif /* __aarch64__ */ #ifdef ARM_COMPUTE_AARCH64_V8_2 - if(ci.CPU == CPUTarget::A75_DOT) + if(ci.CPU == CPUTarget::A75_DOT || ci.CPU == CPUTarget::A55_DOT) { // Configure matrix multiply kernel GemmInterleaved gemm(&ci, M, N, K, false, false); @@ -87,10 +87,6 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::configure(const ITensor *a, const ITe _mm_kernel = std::move(k); _workspace.allocator()->allocate(); } - else if(ci.CPU == CPUTarget::A55_DOT) - { - ARM_COMPUTE_ERROR_ON("WIP"); - } else #elif defined(ARM_COMPUTE_AARCH64_V8A) if(ci.CPU == CPUTarget::A53) -- cgit v1.2.1