From 71d9b57aac146ae3ad5648c1308a872cea90070d Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Fri, 6 Jul 2018 17:05:59 +0100 Subject: COMPMID-1381: Cleaned up the AssemblyHelper interface Introduced a new IFunction for when we'll fork the arm_gemm functions Increased encapsulation and abstraction of which method is used Change-Id: I5fd8b14b5c77e7f8ecb09029b5e2eccd10dbdcf4 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139108 Tested-by: Jenkins Reviewed-by: Georgios Pinitas Reviewed-by: Pablo Tello --- src/runtime/NEON/functions/NEConvolutionLayer.cpp | 1 + src/runtime/NEON/functions/NEGEMM.cpp | 21 +- .../NEON/functions/NEGEMMAssemblyDispatch.cpp | 252 +++++++++++++++++++++ .../NEON/functions/NEGEMMConvolutionLayer.cpp | 13 +- .../NEGEMMLowpAssemblyMatrixMultiplyCore.cpp | 16 +- .../functions/NEGEMMLowpMatrixMultiplyCore.cpp | 26 +-- .../NEON/functions/NEWinogradConvolutionLayer.cpp | 10 +- 7 files changed, 298 insertions(+), 41 deletions(-) create mode 100644 src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp (limited to 'src/runtime/NEON') diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp index 4018407153..d71fd5a715 100644 --- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp @@ -26,6 +26,7 @@ #include "arm_compute/core/PixelValue.h" #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" +#include "arm_compute/runtime/NEON/NEScheduler.h" #include "support/ToolchainSupport.h" #include diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp index 795ffc5d1c..c958904b93 100644 --- a/src/runtime/NEON/functions/NEGEMM.cpp +++ b/src/runtime/NEON/functions/NEGEMM.cpp @@ -29,8 +29,8 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" -#include "arm_compute/runtime/NEON/AssemblyHelper.h" #include "arm_compute/runtime/NEON/NEScheduler.h" +#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h" #include "arm_compute/runtime/TensorAllocator.h" #include "support/ToolchainSupport.h" @@ -39,8 +39,8 @@ namespace arm_compute { NEGEMM::NEGEMM(std::shared_ptr memory_manager) - : _memory_group(std::move(memory_manager)), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(), _ma_kernel(), _tmp_a(), _tmp_b(), _workspace(), _B_pretransposed(), - _original_b(nullptr), _run_vector_matrix_multiplication(false), _run_addition(false), _reshape_b_only_on_first_run(false), _is_prepared(false) + : _memory_group(memory_manager), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(memory_manager), _ma_kernel(), _tmp_a(), _tmp_b(), _original_b(nullptr), + _run_vector_matrix_multiplication(false), _run_addition(false), _reshape_b_only_on_first_run(false), _is_prepared(false) { } @@ -67,10 +67,13 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run(); _run_vector_matrix_multiplication = a->info()->dimension(1) < 2; _original_b = b; - _asm_glue._optimised_kernel = nullptr; - const bool run_optimised = a->info()->data_type() == DataType::F32 && (c == nullptr || beta == 0.f) - && setup_assembly_kernel(a, b, d, alpha, beta, _reshape_b_only_on_first_run, _workspace, _B_pretransposed, _memory_group, _asm_glue); + bool run_optimised = a->info()->data_type() == DataType::F32 && (c == nullptr || beta == 0.f); + if(run_optimised) + { + 
_asm_glue.configure(a, b, d, alpha, beta, _reshape_b_only_on_first_run); + run_optimised = _asm_glue.is_configured(); + } // Check if the first input tensor is a vector. // If so, all the kernels for reshaping the tensors can be skipped @@ -150,7 +153,7 @@ void NEGEMM::run() { prepare(); - if(_asm_glue._optimised_kernel != nullptr) + if(_asm_glue.is_configured()) { _memory_group.acquire(); _asm_glue.run(); @@ -188,14 +191,14 @@ void NEGEMM::prepare() { if(!_is_prepared) { - if(_asm_glue._optimised_kernel) + if(_asm_glue.is_configured()) { ARM_COMPUTE_ERROR_ON(!_original_b->is_used()); _asm_glue.prepare(); _original_b->mark_as_unused(); } - else if(_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication && !_asm_glue._optimised_kernel) + else if(_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication && !_asm_glue.is_configured()) { ARM_COMPUTE_ERROR_ON(!_original_b->is_used()); diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp new file mode 100644 index 0000000000..f6111a31bc --- /dev/null +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -0,0 +1,252 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h" + +#include "arm_compute/runtime/NEON/NEScheduler.h" + +using namespace arm_compute; + +template +NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr memory_manager) + : _function(nullptr), _arm_gemm(), _memory_group(std::move(memory_manager)) +{ +} + +template +void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint) +{ + //TODO(antbar01) Check heuristics here to figure out if we should use an ACL IFunction + _arm_gemm.configure(a, b, d, alpha, beta, pretranspose_hint, _memory_group); +} + +template +void NEGEMMAssemblyDispatch::prepare() +{ + if(_function != nullptr) + { + _function->prepare(); + } + else + { + _arm_gemm.prepare(); + } +} + +template +bool NEGEMMAssemblyDispatch::is_configured() const +{ + return _arm_gemm.is_configured() || _function != nullptr; +} + +template +void NEGEMMAssemblyDispatch::run() +{ + _memory_group.acquire(); + if(_function != nullptr) + { + _function->run(); + } + else + { + _arm_gemm.run(); + } + _memory_group.release(); +} + +#ifndef __aarch64__ +template <> +void NEGEMMAssemblyDispatch::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, MemoryGroup &memory_group) +{ + // arm_gemm::gemm for 8bit only exists for aarch64 + ARM_COMPUTE_UNUSED(a); + ARM_COMPUTE_UNUSED(b); + ARM_COMPUTE_UNUSED(d); + ARM_COMPUTE_UNUSED(alpha); + ARM_COMPUTE_UNUSED(beta); + ARM_COMPUTE_UNUSED(pretranspose_hint); + ARM_COMPUTE_UNUSED(memory_group); +} + +template <> +void NEGEMMAssemblyDispatch::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, MemoryGroup &memory_group) +{ + // arm_gemm::gemm for 8bit only exists for aarch64 + ARM_COMPUTE_UNUSED(a); + ARM_COMPUTE_UNUSED(b); + ARM_COMPUTE_UNUSED(d); + ARM_COMPUTE_UNUSED(alpha); + ARM_COMPUTE_UNUSED(beta); + ARM_COMPUTE_UNUSED(pretranspose_hint); + ARM_COMPUTE_UNUSED(memory_group); +} +#endif // aarch64 +template +void NEGEMMAssemblyDispatch::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, MemoryGroup &memory_group) +{ + const CPUInfo &ci = NEScheduler::get().cpu_info(); + const int M = d->info()->tensor_shape().y(); + const int N = d->info()->tensor_shape().x(); + const int K = a->info()->tensor_shape().x(); + const int batches = d->info()->tensor_shape().total_size_upper(2); + const int multis = b->info()->tensor_shape().z(); + unsigned int num_threads = NEScheduler::get().num_threads(); + + _gemm_kernel_asm = arm_gemm::gemm(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, pretranspose_hint); + if(_gemm_kernel_asm == nullptr) + { + //configuration not supported: Leave function unconfigured: + return; + } + + // arm_compute wrapper for the Gemm object (see above) + std::unique_ptr> acl_gemm_wrapper = support::cpp14::make_unique>(); + ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr); + acl_gemm_wrapper->configure(_gemm_kernel_asm.get()); + const size_t workspace_size = _gemm_kernel_asm->get_working_size(); + if(workspace_size > 0) + { + // Allocate workspace + const unsigned int alignment = 4096; + //FIXME: is memory_group ever null ? 
+ allocate_workspace(workspace_size, &memory_group, alignment); + } + + //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and + //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001 + { + const unsigned int window_size = _gemm_kernel_asm->get_window_size(); + if(window_size < num_threads) + { + num_threads = window_size; + _gemm_kernel_asm->set_nthreads(num_threads); + } + } + + _optimised_kernel = std::move(acl_gemm_wrapper); + _a = a; + _b = b; + _d = d; + // Check for pre-transposed support + if(_gemm_kernel_asm->B_pretranspose_required()) + { + // Forcing 128-byte alignment (required by 32-bit kernels) + const unsigned int alignment = 128; + const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size(); + _pretranspose.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment); + _pretranspose.allocator()->allocate(); + ARM_COMPUTE_ERROR_ON_NULLPTR(_pretranspose.buffer()); + } +} + +template +void NEGEMMAssemblyDispatch::Fallback::prepare() +{ + if(!_is_prepared) + { + // Pretranspose B if required + if(_gemm_kernel_asm->B_pretranspose_required()) + { + const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); + const auto in1_ptr = reinterpret_cast(_b->buffer()); + const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput); + + ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr); + _gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b); + _b->mark_as_unused(); + } + + _is_prepared = true; + } +} + +template +void NEGEMMAssemblyDispatch::Fallback::allocate_workspace(size_t workspace_size, MemoryGroup *memory_group, size_t alignment) +{ + ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0"); + _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment); + if(memory_group != nullptr) + { + memory_group->manage(&_workspace); + } + _workspace.allocator()->allocate(); +} + +template +bool NEGEMMAssemblyDispatch::Fallback::is_configured() const +{ + return _optimised_kernel != nullptr; +} + +template +void NEGEMMAssemblyDispatch::Fallback::run() +{ + const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput); + const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); + const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput); + + // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is + // the relevant multiple of the row stride. + const bool is_nhwc = _a->info()->data_layout() == DataLayout::NHWC; + const int stride_in_bytes_a = is_nhwc ? 
_a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z(); + + const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput); + const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput); + + const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput); + const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput); + const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput); + + const auto in0_ptr = reinterpret_cast(_a->buffer()); + const auto in1_ptr = reinterpret_cast(_b->buffer()); + auto out_ptr = reinterpret_cast(_d->buffer()); + + // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads + if(_workspace.buffer() != nullptr) + { + _gemm_kernel_asm->set_working_space(reinterpret_cast(_workspace.buffer())); + const unsigned int window_size = _gemm_kernel_asm->get_window_size(); + unsigned int num_threads = NEScheduler::get().num_threads(); + if(window_size < num_threads) + { + num_threads = window_size; + _gemm_kernel_asm->set_nthreads(num_threads); + } + } + + // Prepare assembly kernel + prepare(); + + // Set gemm parameters + _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d); + + // Schedule assembly kernel + NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX); +} + +namespace arm_compute +{ +template class NEGEMMAssemblyDispatch; +template class NEGEMMAssemblyDispatch; +template class NEGEMMAssemblyDispatch; +} //namespace arm_compute diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp index 94ef4e7b32..25e8d9e60b 100644 --- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp @@ -224,9 +224,9 @@ Status validate_and_initialize_values(const ITensorInfo *input, const ITensorInf } // namespace NEGEMMConvolutionLayer::NEGEMMConvolutionLayer(const std::shared_ptr &memory_manager) - : _asm_glue(), _memory_group(memory_manager), _input_im2col_kernel(), _input_interleave_kernel(), _reshape_weights(), _mm_kernel(), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), - _output_col2im_kernel(), _activationlayer_function(), _add_bias_kernel(), _original_weights(nullptr), _input_im2col_reshaped(), _input_interleaved_reshaped(), _weights_reshaped(), _gemm_output(), - _tmp_output(), _workspace(), _B_pretransposed(), _data_layout(DataLayout::NCHW), _append_bias(false), _is_fully_connected_convolution(false), _are_weights_reshaped(false), _is_quantized(false), + : _memory_group(memory_manager), _asm_glue(memory_manager), _input_im2col_kernel(), _input_interleave_kernel(), _reshape_weights(), _mm_kernel(), _mm_gemmlowp(memory_manager), + _gemmlowp_output_stage(), _output_col2im_kernel(), _activationlayer_function(), _add_bias_kernel(), _original_weights(nullptr), _input_im2col_reshaped(), _input_interleaved_reshaped(), + _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _append_bias(false), _is_fully_connected_convolution(false), _are_weights_reshaped(false), _is_quantized(false), _is_interleaved(false), _is_activationlayer_enabled(false), _skip_im2col(false), _is_prepared(false) { } @@ -384,7 +384,8 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig // Configure matrix multiply if(run_optimised) { - 
if(!setup_assembly_kernel(_skip_im2col ? input : &_input_im2col_reshaped, weights, is_nhwc ? output : &_gemm_output, 1.f, 0.f, true, _workspace, _B_pretransposed, _memory_group, _asm_glue)) + _asm_glue.configure(_skip_im2col ? input : &_input_im2col_reshaped, weights, is_nhwc ? output : &_gemm_output, 1.f, 0.f, true); + if(!_asm_glue.is_configured()) { ARM_COMPUTE_ERROR("setup_assembly_kernel failed."); } @@ -587,7 +588,7 @@ void NEGEMMConvolutionLayer::run() } // Runs matrix multiply on reshaped matrices - if(_asm_glue._optimised_kernel != nullptr) + if(_asm_glue.is_configured()) { _asm_glue.run(); } @@ -652,7 +653,7 @@ void NEGEMMConvolutionLayer::prepare() } // Run GEMM prepare stage - if(_asm_glue._optimised_kernel) + if(_asm_glue.is_configured()) { _asm_glue.prepare(); } diff --git a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp index 98b476794d..9b5d02ca44 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp @@ -38,8 +38,8 @@ using namespace arm_compute; NEGEMMLowpAssemblyMatrixMultiplyCore::NEGEMMLowpAssemblyMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager) - : _memory_group(std::move(memory_manager)), _asm_glue_unsigned(), _asm_glue_signed(), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _tmp_a(), _tmp_b(), - _workspace(), _B_pretransposed() + : _memory_group(memory_manager), _asm_glue_unsigned(memory_manager), _asm_glue_signed(memory_manager), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _tmp_a(), + _tmp_b() { } @@ -53,18 +53,19 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::configure(const ITensor *a, const ITe ARM_COMPUTE_ERROR_ON_MSG((b)->info()->dimension(0) != (output)->info()->dimension(0), "The output matrix must have the same number of columns as the matrix B"); bool run_optimised = false; -#ifdef __aarch64__ switch(a->info()->data_type()) { case DataType::S8: { - run_optimised = setup_assembly_kernel(a, b, output, 1.f, 0.f, true, _workspace, _B_pretransposed, _memory_group, _asm_glue_signed); + _asm_glue_signed.configure(a, b, output, 1.f, 0.f, true); + run_optimised = _asm_glue_signed.is_configured(); break; } case DataType::QASYMM8: case DataType::U8: { - run_optimised = setup_assembly_kernel(a, b, output, 1.f, 0.f, true, _workspace, _B_pretransposed, _memory_group, _asm_glue_unsigned); + _asm_glue_unsigned.configure(a, b, output, 1.f, 0.f, true); + run_optimised = _asm_glue_unsigned.is_configured(); break; } default: @@ -73,7 +74,6 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::configure(const ITensor *a, const ITe break; } } -#endif /* __aarch64__ */ if(!run_optimised) { // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ] @@ -133,11 +133,11 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::run() NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY); } - if(_asm_glue_unsigned._optimised_kernel != nullptr) + if(_asm_glue_unsigned.is_configured()) { _asm_glue_unsigned.run(); } - else if(_asm_glue_signed._optimised_kernel != nullptr) + else if(_asm_glue_signed.is_configured()) { _asm_glue_signed.run(); } diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp index a92ffa7c7b..a57271c17c 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp +++
b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp @@ -41,8 +41,8 @@ using namespace arm_compute; using namespace arm_compute::misc::shape_calculator; NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager) - : _memory_group(std::move(memory_manager)), _asm_glue_unsigned(), _asm_glue_signed(), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _mtx_a_reduction_kernel(), - _mtx_b_reduction_kernel(), _offset_contribution_kernel(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _workspace(), _B_pretranspose(), _original_b(nullptr), _a_offset(0), _b_offset(0), + : _memory_group(memory_manager), _asm_glue_unsigned(memory_manager), _asm_glue_signed(memory_manager), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), + _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(), _offset_contribution_kernel(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _original_b(nullptr), _a_offset(0), _b_offset(0), _run_vector_matrix_multiplication(false), _dot_product_path(false), _reshape_b_only_on_first_run(false), _is_prepared(false) { } @@ -53,10 +53,8 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, ARM_COMPUTE_ERROR_THROW_ON(NEGEMMLowpMatrixMultiplyCore::validate(a->info(), b->info(), output->info(), gemm_info)); // Clear state - _mtx_a_reshape_kernel = nullptr; - _mtx_b_reshape_kernel = nullptr; - _asm_glue_signed._optimised_kernel = nullptr; - _asm_glue_unsigned._optimised_kernel = nullptr; + _mtx_a_reshape_kernel = nullptr; + _mtx_b_reshape_kernel = nullptr; // Set internal variables _a_offset = a->info()->quantization_info().offset; @@ -71,13 +69,15 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, { case DataType::S8: { - _dot_product_path = setup_assembly_kernel(a, b, output, 1.f, 0.f, _reshape_b_only_on_first_run, _workspace, _B_pretranspose, _memory_group, _asm_glue_signed); + _asm_glue_signed.configure(a, b, output, 1.f, 0.f, _reshape_b_only_on_first_run); + _dot_product_path = _asm_glue_signed.is_configured(); break; } case DataType::QASYMM8: case DataType::U8: { - _dot_product_path = setup_assembly_kernel(a, b, output, 1.f, 0.f, _reshape_b_only_on_first_run, _workspace, _B_pretranspose, _memory_group, _asm_glue_unsigned); + _asm_glue_unsigned.configure(a, b, output, 1.f, 0.f, _reshape_b_only_on_first_run); + _dot_product_path = _asm_glue_unsigned.is_configured(); break; } default: @@ -275,11 +275,11 @@ void NEGEMMLowpMatrixMultiplyCore::run() } // Run GEMM - if(_asm_glue_unsigned._optimised_kernel != nullptr) + if(_asm_glue_unsigned.is_configured()) { _asm_glue_unsigned.run(); } - else if(_asm_glue_signed._optimised_kernel != nullptr) + else if(_asm_glue_signed.is_configured()) { _asm_glue_signed.run(); } @@ -311,15 +311,15 @@ void NEGEMMLowpMatrixMultiplyCore::prepare() if(!_is_prepared) { // Run assembly reshape - if((_asm_glue_signed._optimised_kernel || _asm_glue_signed._optimised_kernel) && _reshape_b_only_on_first_run) + if((_asm_glue_unsigned.is_configured() || _asm_glue_signed.is_configured()) && _reshape_b_only_on_first_run) { ARM_COMPUTE_ERROR_ON(!_original_b->is_used()); - if(_asm_glue_unsigned._optimised_kernel != nullptr) + if(_asm_glue_unsigned.is_configured()) { _asm_glue_unsigned.prepare(); } - else if(_asm_glue_signed._optimised_kernel != nullptr) + else if(_asm_glue_signed.is_configured()) { _asm_glue_signed.prepare(); } diff --git
a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp index 31e17e5cf9..a1d801e574 100644 --- a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp @@ -24,16 +24,15 @@ #include "arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h" #include "arm_compute/core/Error.h" +#include "arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h" #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" #include "arm_compute/core/Validate.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "arm_compute/runtime/NEON/AssemblyHelper.h" #include "arm_compute/runtime/NEON/NEScheduler.h" +#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h" #include "support/ToolchainSupport.h" -#include "arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h" - #include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp" namespace arm_compute @@ -292,7 +291,7 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor * _arm_gemm->set_arrays(reinterpret_cast(_input_workspace.buffer()), input_matrix_row_stride, 0, input_matrix_stride, reinterpret_cast(_kernel_storage.buffer()), kernel_matrix_row_stride, kernel_matrix_stride, reinterpret_cast(_output_workspace.buffer()), output_matrix_row_stride, 0, output_matrix_stride); - auto acl_gemm_wrapper = support::cpp14::make_unique>>(); + auto acl_gemm_wrapper = support::cpp14::make_unique>(); acl_gemm_wrapper->configure(_arm_gemm.get()); const size_t workspace_size = _arm_gemm->get_working_size(); @@ -302,7 +301,8 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor * const unsigned int alignment = 4096; // TODO (COMPMID-1248) : Add support for memory manager in NEWinogradConvolutionLayer // Warning : Do not set a memory group in allocate_workspace, should be done under COMPMID-1248 - allocate_workspace(workspace_size, _workspace, nullptr, alignment); + _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment); + _workspace.allocator()->allocate(); _arm_gemm->set_working_space(reinterpret_cast(_workspace.buffer())); } -- cgit v1.2.1
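To make the new call sequence concrete, here is a minimal caller-side sketch of the NEGEMMAssemblyDispatch interface introduced above. It is not part of the commit: run_gemm_f32() is a hypothetical helper, and the <float, float> template arguments are inferred from the explicit instantiations at the end of NEGEMMAssemblyDispatch.cpp.

#include "arm_compute/core/ITensor.h"
#include "arm_compute/runtime/IMemoryManager.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"

#include <memory>

using namespace arm_compute;

// Hypothetical helper showing the configure / is_configured / prepare / run
// sequence that NEGEMM, NEGEMMConvolutionLayer and the GEMMLowp cores now follow.
void run_gemm_f32(const ITensor *a, const ITensor *b, ITensor *d, std::shared_ptr<IMemoryManager> mm)
{
    NEGEMMAssemblyDispatch<float, float> asm_glue(std::move(mm));

    // configure() simply leaves the object unconfigured when no assembly kernel
    // is available, instead of exposing an _optimised_kernel pointer as the old
    // AssemblyHelper did.
    asm_glue.configure(a, b, d, 1.f, 0.f, /* pretranspose_hint */ true);
    if(!asm_glue.is_configured())
    {
        return; // caller falls back to the generic NEON kernels
    }

    asm_glue.prepare(); // one-off work, e.g. pretransposing B
    asm_glue.run();     // per-call execution
}

This is exactly the pattern NEGEMM::configure() uses above: run_optimised stays true only if is_configured() returns true after the configure() call.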
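Fallback::run() converts the tensors' byte strides into the element-count leading dimensions and batch/multi strides that arm_gemm expects, treating the output as 3D in the NHWC case. The standalone sketch below restates that mapping; GemmStrides and compute_strides() are illustrative names, and the ITensorInfo accessors are assumed to behave as in the 18.x API.

#include "arm_compute/core/ITensorInfo.h"
#include "arm_compute/core/Types.h"

struct GemmStrides
{
    int lda, ldb, ldd;
    int batch_stride_a, batch_stride_d;
    int multi_stride_a, multi_stride_b, multi_stride_d;
};

template <typename TypeInput, typename TypeOutput>
GemmStrides compute_strides(const arm_compute::ITensorInfo &a, const arm_compute::ITensorInfo &b, const arm_compute::ITensorInfo &d)
{
    GemmStrides s{};
    // Leading dimensions: row strides in elements rather than bytes.
    s.lda = a.strides_in_bytes().y() / sizeof(TypeInput);
    s.ldb = b.strides_in_bytes().y() / sizeof(TypeInput);
    s.ldd = d.strides_in_bytes().y() / sizeof(TypeOutput);

    // In NHWC the output shape is interpreted as 3D, so A's batch stride is the
    // relevant multiple of its row stride instead of its plane (z) stride.
    const bool is_nhwc           = a.data_layout() == arm_compute::DataLayout::NHWC;
    const int  stride_in_bytes_a = is_nhwc ? a.strides_in_bytes().y() * d.dimension(1) : a.strides_in_bytes().z();

    s.batch_stride_a = stride_in_bytes_a / sizeof(TypeInput);
    s.batch_stride_d = d.strides_in_bytes().z() / sizeof(TypeOutput);

    // Multi (4th-dimension) strides for A and D; B's "multi" dimension is z.
    s.multi_stride_a = a.strides_in_bytes()[3] / sizeof(TypeInput);
    s.multi_stride_b = b.strides_in_bytes().z() / sizeof(TypeInput);
    s.multi_stride_d = d.strides_in_bytes()[3] / sizeof(TypeOutput);
    return s;
}

Only A's batch stride changes with the layout; D's batch stride stays z-based, matching the set_arrays() call in Fallback::run().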
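Both Fallback::allocate_workspace() and the Winograd path above allocate scratch memory the same way: an S8 tensor over-sized by the requested alignment (the workaround the FIXME ties to COMPMID-1088), initialised with that alignment, and optionally handed to a memory group so the actual allocation is deferred to the memory manager. A sketch under the same assumptions; allocate_aligned_workspace() is an illustrative name.

#include "arm_compute/core/Error.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/MemoryGroup.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void allocate_aligned_workspace(Tensor &workspace, size_t workspace_size, MemoryGroup *memory_group, size_t alignment)
{
    ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
    // Over-allocate by `alignment` bytes and request that alignment from the allocator.
    workspace.allocator()->init(TensorInfo(TensorShape{ workspace_size + alignment }, 1, DataType::S8), alignment);
    if(memory_group != nullptr)
    {
        memory_group->manage(&workspace); // let the memory manager place the buffer
    }
    workspace.allocator()->allocate();
}

The assembly dispatch uses a 4096-byte alignment for its workspace and the same padded-init pattern with a 128-byte alignment for the pretransposed B buffer, while NEWinogradConvolutionLayer deliberately passes no memory group until COMPMID-1248 adds memory-manager support there.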