diff options
Diffstat (limited to 'src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp | 73 |
1 files changed, 52 insertions, 21 deletions
diff --git a/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp b/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp index 099934a49d..7827bc1ccf 100644 --- a/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp +++ b/src/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.cpp @@ -38,6 +38,7 @@ namespace arm_compute { #include "arm_compute/core/NEON/kernels/assembly/gemm_interleaved.hpp" +#include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_s8_12x8.hpp" #include "arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_12x8.hpp" } // namespace arm_compute @@ -54,7 +55,7 @@ using namespace arm_compute; Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output) { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::U8, DataType::S8); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1); @@ -80,6 +81,38 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; return std::make_pair(err, win); } + +template <typename strategy, typename To, typename Tr> +void *align_workspace(GemmInterleaved<strategy, To, Tr> &gemm, const ThreadInfo &info, ITensor *ws) +{ + constexpr size_t alignment = 4096; + const size_t offset = (gemm.get_working_size() + alignment - 1) * info.thread_id; + void *workspace = ws->buffer() + offset; + size_t workspace_size = ws->info()->total_size(); + + if(support::cpp11::align(alignment, gemm.get_working_size(), workspace, workspace_size) == nullptr) + { + ARM_COMPUTE_ERROR("Not enough space to align buffer!"); + } + return workspace; +} + +template <typename strategy> +void execute_gemm(const Window &win, Iterator &in0, Iterator &in1, Iterator &out, + const ThreadInfo &info, ITensor *ws, int M, int N, int K, bool is_transposed_0, bool is_transposed_1, + int lda, int ldb, int ldc, float alpha, float beta) +{ + GemmInterleaved<strategy, typename strategy::operand_type, typename strategy::result_type> gemm(&info.cpu_info, M, N, K, is_transposed_0, is_transposed_1); + void *workspace = align_workspace(gemm, info, ws); + execute_window_loop(win, [&](const Coordinates & id) + { + gemm.execute(reinterpret_cast<const typename strategy::operand_type *>(in0.ptr()), lda, + reinterpret_cast<const typename strategy::operand_type *>(in1.ptr()), ldb, + reinterpret_cast<typename strategy::result_type *>(out.ptr()), ldc, + alpha, beta, workspace); + }, + in0, out); +} } // namespace namespace arm_compute @@ -123,8 +156,6 @@ void NEGEMMLowpAArch64V8P4Kernel::run(const Window &window, const ThreadInfo &in const int ldb = _input1->info()->strides_in_bytes().y(); const int ldc = _output->info()->strides_in_bytes().y() / sizeof(uint32_t); - const auto in1_ptr = reinterpret_cast<const gemm_u8_12x8::operand_type *>(_input1->buffer()); - const int M = std::min(_output->info()->tensor_shape().y(), static_cast<size_t>(window.y().end())) - window.y().start(); const int N = _output->info()->tensor_shape().x(); const int K = _input0->info()->tensor_shape().x(); @@ -135,28 +166,28 @@ void NEGEMMLowpAArch64V8P4Kernel::run(const Window &window, const ThreadInfo &in win.set(1, Window::Dimension(0, 1, 1)); Iterator in0(_input0, window); + Iterator in1(_input1, window); Iterator out(_output, window); - GemmInterleaved<gemm_u8_12x8, gemm_u8_12x8::operand_type, gemm_u8_12x8::result_type> gemm(&info.cpu_info, M, N, K, _is_transposed_0, _is_transposed_1); - - constexpr size_t alignment = 4096; - const size_t offset = (gemm.get_working_size() + alignment - 1) * info.thread_id; - void *workspace = _workspace->buffer() + offset; - size_t workspace_size = _workspace->info()->total_size(); - - if(support::cpp11::align(alignment, gemm.get_working_size(), workspace, workspace_size) == nullptr) + switch(_input0->info()->data_type()) { - ARM_COMPUTE_ERROR("Not enough space to align buffer!"); + case DataType::QASYMM8: + case DataType::U8: + { + execute_gemm<gemm_u8_12x8>(win, in0, in1, out, info, _workspace, M, N, K, _is_transposed_0, _is_transposed_1, lda, ldb, ldc, _alpha, _beta); + break; + } + case DataType::S8: + { + execute_gemm<gemm_s8_12x8>(win, in0, in1, out, info, _workspace, M, N, K, _is_transposed_0, _is_transposed_1, lda, ldb, ldc, _alpha, _beta); + break; + } + default: + { + ARM_COMPUTE_ERROR("Not supported."); + break; + } } - - execute_window_loop(win, [&](const Coordinates & id) - { - gemm.execute(reinterpret_cast<const gemm_u8_12x8::operand_type *>(in0.ptr()), lda, - reinterpret_cast<const gemm_u8_12x8::operand_type *>(in1_ptr), ldb, - reinterpret_cast<gemm_u8_12x8::result_type *>(out.ptr()), ldc, - _alpha, _beta, workspace); - }, - in0, out); } } // namespace arm_compute #endif /* ARM_COMPUTE_AARCH64_V8_2 */ |