From 7cd26d4a1b14bc4bf7c61496803416ab3d84791f Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 9 Jan 2019 18:35:17 +0000 Subject: COMPMID-1867: Add NEON/SVE GEMM Hybrid kernels. Change-Id: Ib40a9921e7f9a6a8be6c38872d6b3a0f24ed0cd3 Reviewed-on: https://review.mlplatform.org/515 Reviewed-by: Anthony Barbier Tested-by: Arm Jenkins --- arm_compute/core/NEON/kernels/assembly/Helpers.h | 41 +++--- .../NEGEMMInterleavedMatrixMultiplyWrapper.h | 130 ++++++++++++++++--- .../NEGEMMInterleavedPrepareBWrapperKernel.h | 140 +++++++++++++++++++-- .../assembly/NEGEMMInterleavedTransformAWrapper.h | 71 ++++++++++- .../core/NEON/kernels/assembly/arm_gemm.hpp | 88 ++++++++----- .../core/NEON/kernels/assembly/gemm_common.hpp | 11 +- .../functions/assembly/NEGEMMInterleavedWrapper.h | 17 ++- 7 files changed, 400 insertions(+), 98 deletions(-) (limited to 'arm_compute') diff --git a/arm_compute/core/NEON/kernels/assembly/Helpers.h b/arm_compute/core/NEON/kernels/assembly/Helpers.h index 11c4c08086..e2a46e96a3 100644 --- a/arm_compute/core/NEON/kernels/assembly/Helpers.h +++ b/arm_compute/core/NEON/kernels/assembly/Helpers.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -27,6 +27,9 @@ #include "arm_compute/core/CPP/CPPTypes.h" #include "arm_compute/core/Utils.h" +#include "arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h" +#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp" + namespace arm_compute { /** Block sizes to use to break the M, N, K dimension */ @@ -38,31 +41,29 @@ struct BlockSizes unsigned int strategy_out_height{ 0 }; /**< Number of rows (M) processed by the selected strategy */ }; -/** Calculate the recommended block sizes to use based on the CPU cache sizes and data type - * - * @param[in] ci CPU information - * @param[in] M M dimension. - * @param[in] N N dimension. - * @param[in] K K dimension. - * @param[in] input_type Input data type - * @param[in] use_dot (Optional) If data_type is QASYMM8/U8/S8, then use the dot product instruction ? - * - * @return Recommeded block sizes to use for the given M, N, K dimensions. - */ -BlockSizes calculate_block_sizes_from_data_type(const CPUInfo &ci, unsigned int M, unsigned int N, unsigned int K, DataType input_type, bool use_dot = false); - -/** Get the name of the GEMM strategy which will be used for a given input type +/** Extracts the kernel description of the selected kernel by the GEMM backend heuristics * - * @param[in] input_type Input data type - * @param[in] use_dot (Optional) If data_type is QASYMM8/U8/S8, then use the dot product instruction ? + * @param[in] input_type Data type of the input tensor. + * @param[in] ci CPU information. + * @param[in] num_threads Maximum number of threads that might be used for the calculations. + * @param[in] p M, N, K sizes. + * @param[in] alpha Alpha value. + * @param[in] beta Beta value. + * @param[in] pretranspose_hint Is B also pretransposed ? * - * @return The name of the strategy that will be used + * @return Kernel description that the assembly heuristics picked for the given configuration */ -const char *get_strategy_name(DataType input_type, bool use_dot = false); +arm_gemm::KernelDescription get_gemm_info(DataType input_type, + const CPUInfo &ci, + const unsigned int num_threads, + const INEGEMMWrapperKernel::Params &p, + float alpha, + float beta, + bool pretranspose_hint); /** Calculate the recommended block sizes to use based on the CPU cache sizes and the strategy which will be used * - * @param[in] ci CPU information + * @param[in] ci CPU information. * @param[in] M M dimension. * @param[in] N N dimension. * @param[in] K K dimension. diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h index 46a05abcdb..e2b849aa3d 100644 --- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h +++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -26,8 +26,13 @@ #include "arm_compute/core/NEON/kernels/assembly/Helpers.h" +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/ITensor.h" #include "arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h" +#include "arm_compute/core/Utils.h" +#include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" +#include "arm_compute/core/WindowIterator.h" namespace arm_compute { @@ -84,7 +89,7 @@ public: }; /** Equivalent to arm_gemm::GemmInterleaved's strategy::kernel() but using Compute Library types. */ -template +template class NEGEMMInterleavedMatrixMultiplyWrapperTemplate : public NEGEMMInterleavedMatrixMultiplyWrapper { public: @@ -94,7 +99,7 @@ public: * @param[in] transformed_b Already reshaped matrix B. * @param[out] tmp_c Temporary buffer to be used to store intermediate results. * @param[in,out] c Result matrix C. - * @param[in] batch_window Window containing iteration information for the M and batch dimensions. + * @param[in] block_walker Window containing iteration information for the M and batch dimensions. * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes). * @param[in] params M, N, K sizes. * @param[in] is_pretransposed Is B also pretransposed ? @@ -102,30 +107,117 @@ public: * @param[in] beta Beta value * @param[in] max_num_threads Maximum number of threads that might be used for the calculations. */ - void configure(const ITensor *prepared_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &batch_window, const BlockSizes &block_sizes, - const INEGEMMWrapperKernel::Params ¶ms, bool b_is_pretransposed, float alpha, float beta, unsigned int max_num_threads); + void configure(const ITensor *prepared_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker, const BlockSizes &block_sizes, + const INEGEMMWrapperKernel::Params ¶ms, bool b_is_pretransposed, float alpha, float beta, unsigned int max_num_threads) + { + _prepared_a = prepared_a; + _transformed_b = transformed_b; + _tmp_c = tmp_c; + _c = c; + _block_walker = block_walker; + _block_sizes = block_sizes; + _params = params; + _b_is_pretransposed = b_is_pretransposed; + _alpha = alpha; + _beta = beta; + + auto_init_if_empty(*_tmp_c->info(), c->info()->clone()->set_tensor_shape(TensorShape{ _block_sizes.x_block * strategy::out_height(), max_num_threads })); + } // Inherited methods overridden: - void transform(const MatrixMultiplyWorkload &wl, const ThreadInfo &info, const Window &batch_window, const Coordinates &start_offset, const Coordinates &end_offset) override; - void create_workloads(std::vector &workloads) override; + void transform(const MatrixMultiplyWorkload &wl, const ThreadInfo &info, const Window &batch_window, const Coordinates &start_offset, const Coordinates &end_offset) override + { + strategy strat(info.cpu_info); + TensorAccessor prepared_a(*_prepared_a); + TensorAccessor transformed_b(*_transformed_b); + TensorAccessor c(*_c); + TensorAccessor tmp_c(*_tmp_c); + + int prev_batch = -1; + typename strategy::operand_type *a_ptr = nullptr; + auto window_iterator = arm_compute::create_window_iterator(batch_window, start_offset, end_offset, [&](const Coordinates & id) + { + const unsigned int y = id.x(); + const unsigned int batch = id.y(); + const unsigned int ymax = std::min(_params.M, y + strategy::out_height()); + + // If it's the first block of a new batch then reset the pointer to A. + if(prev_batch != static_cast(batch)) + { + const unsigned int first_m = id.x(); + a_ptr = prepared_a(0, first_m, batch); + prev_batch = batch; + } + + // Call matrix multiply assembly routine to process the block: + strat.kernel(a_ptr, transformed_b(wl._offset_transformed_b), tmp_c(0, info.thread_id), 1, wl._bblocks, wl._kern_k); + a_ptr += strategy::out_height() * wl._kern_k; + + // Merge the result with the other blocks' results: + strat.transforms.Merge(c(0, 0, batch, wl._multi), tmp_c(0, info.thread_id), c.stride(1), y, ymax, wl._x0, wl._xmax, _alpha, (wl._k0 == 0 ? _beta : static_cast(1))); + }); + auto on_new_row_size = [&](unsigned int start, unsigned int end) + { + //Nothing to do + }; + window_iterator.iterate_2D(on_new_row_size); + } + void create_workloads(std::vector &workloads) override + { + unsigned int offset_transformed_b = 0; + unsigned int wl_index = 0; + unsigned int num_buffers = 0, reshaped_block_size = 0; + + if(!_b_is_pretransposed) + { + num_buffers = _transformed_b->info()->tensor_shape()[1]; + reshaped_block_size = _transformed_b->info()->tensor_shape()[0]; + } + execute_window_loop(_block_walker, [&](const Coordinates & id) + { + const unsigned int x0 = id.x(); + const unsigned int k0 = id.y(); + const unsigned int multi = id.z(); + + const unsigned int xmax = std::min(x0 + _block_walker.x().step(), _params.N); + const unsigned int kmax = std::min(k0 + _block_walker.y().step(), _params.K); + + // Figure out how many "K" the kernel will actually process. + const int kern_k = ceil_to_multiple(kmax - k0, strategy::k_unroll()); + const int bblocks = DIV_CEIL(xmax - x0, strategy::out_width()); + + workloads.push_back(MatrixMultiplyWorkload(offset_transformed_b, x0, xmax, k0, kmax, multi, kern_k, bblocks)); + + if(_b_is_pretransposed) + { + offset_transformed_b += bblocks * strategy::out_width() * kern_k; + } + else + { + // Rotate through the BufferManager's buffers: + wl_index++; + offset_transformed_b = (wl_index % num_buffers) * reshaped_block_size; + } + }); + } private: const ITensor *_prepared_a { nullptr }; - const ITensor *_transformed_b{ nullptr }; - ITensor *_tmp_c{ nullptr }; - ITensor *_c{ nullptr }; - unsigned int _Nsize{ 0 }; - unsigned int _Ksize{ 0 }; - bool _transpose_b{ false }; - BlockSizes _block_sizes{}; - INEGEMMWrapperKernel::Params _params{}; - Window _block_walker{}; - bool _b_is_pretransposed{ false }; - Tr _alpha{}; - Tr _beta{}; + const ITensor *_transformed_b{ nullptr }; + ITensor *_tmp_c{ nullptr }; + ITensor *_c{ nullptr }; + unsigned int _Nsize{ 0 }; + unsigned int _Ksize{ 0 }; + bool _transpose_b{ false }; + BlockSizes _block_sizes{}; + INEGEMMWrapperKernel::Params _params{}; + Window _block_walker{}; + bool _b_is_pretransposed{ false }; + typename strategy::result_type _alpha{}; + typename strategy::result_type _beta{}; }; } // namespace arm_compute diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h index e46c33018b..ba3223f66d 100644 --- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h +++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,14 +24,16 @@ #ifndef __ARM_COMPUTE_NEGEMMINTERLEAVEDPREPAREBWRAPPERKERNEL_H__ #define __ARM_COMPUTE_NEGEMMINTERLEAVEDPREPAREBWRAPPERKERNEL_H__ +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/ITensor.h" #include "arm_compute/core/NEON/INEKernel.h" #include "arm_compute/core/NEON/kernels/assembly/Helpers.h" #include "arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h" +#include "arm_compute/core/Utils.h" +#include "arm_compute/core/Validate.h" namespace arm_compute { -class ITensor; - /** Unit of work for @ref NEGEMMInterleavedPrepareBWrapperKernel to process */ struct PrepareBWorkload { @@ -56,6 +58,84 @@ struct PrepareBWorkload unsigned int _kmax; /**< Last value to process along the K dimension. */ }; +namespace detail +{ +// Call the lambda function for each workload generated by the passed window. +template +void for_each_element_in_window(const Window &window, const ITensor *b, ITensor *transformed_b, unsigned int N, unsigned int K, Lambda &&lambda) +{ + unsigned int wl_index = 0; + unsigned int num_buffers = 0, reshaped_block_size = 0; + + if(use_buffer_manager) + { + num_buffers = transformed_b->info()->tensor_shape()[1]; + reshaped_block_size = transformed_b->info()->strides_in_bytes().y(); + } + + unsigned int offset_transformed_b = transformed_b->info()->offset_first_element_in_bytes(); + execute_window_loop(window, [&](const Coordinates & coordinates) + { + const unsigned int x0 = coordinates.x(); + const unsigned int k0 = coordinates.y(); + const unsigned int multi = coordinates.z(); + + const unsigned int offset_b = b->info()->offset_element_in_bytes(Coordinates(0, 0, multi)); + const unsigned int xmax = std::min(x0 + window.x().step(), N); + const unsigned int kmax = std::min(k0 + window.y().step(), K); + + /* Figure out the size of each block. */ + unsigned int x_size = (xmax - x0); + unsigned int k_size = (kmax - k0); + + /* Round sizes up as needed. */ + x_size = ceil_to_multiple(x_size, strategy::out_width()); + k_size = ceil_to_multiple(k_size, strategy::k_unroll()); + + lambda(PrepareBWorkload(offset_b, offset_transformed_b, x0, xmax, k0, kmax)); + + //Each workload represents one block: + if(use_buffer_manager) + { + // Rotate through the BufferManager's buffers: + wl_index++; + offset_transformed_b = (wl_index % num_buffers) * reshaped_block_size; + } + else + { + offset_transformed_b += (x_size * k_size * sizeof(typename strategy::operand_type)); + } + }); +} + +// Calculate the size of transformed_b: +template +unsigned int get_B_pretransposed_array_size(unsigned int N, unsigned int K, const BlockSizes &bs, unsigned int multis) +{ + // How many full blocks do N / K contain ? + size_t num_full_k = K / bs.k_block; + size_t num_full_x = N / bs.x_block; + + ARM_COMPUTE_ERROR_ON(bs.x_block % strategy::out_width() != 0); + ARM_COMPUTE_ERROR_ON(bs.k_block % strategy::k_unroll() != 0); + + size_t normal_x_size = bs.x_block; + size_t normal_k_size = bs.k_block; + + // Round up the leftovers to be a multiple of the strategy processing size: + size_t left_over_x_size = ceil_to_multiple(N % bs.x_block, strategy::out_width()); + size_t left_over_k_size = ceil_to_multiple(K % bs.k_block, strategy::k_unroll()); + + // Calculate the total size of the buffer: + size_t total = num_full_k * normal_k_size * (num_full_x * normal_x_size + left_over_x_size); + total += left_over_k_size * (left_over_x_size + num_full_x * normal_x_size); + + total *= multis; + + return total; +} +} // namespace detail + /** Common interface for the templated wrappers around the B reshape NEON assembly implementations */ class NEGEMMInterleavedPrepareBWrapperKernel : public INEKernel { @@ -93,7 +173,7 @@ public: /** Equivalent to arm_gemm::GemmInterleaved's strategy::transforms::PrepareB() but using Compute Library types. */ -template +template class NEGEMMInterleavedPrepareBWrapperKernelTemplate : public NEGEMMInterleavedPrepareBWrapperKernel { public: @@ -105,13 +185,55 @@ public: * @param[in] ci CPU information * @param[in] params M, N, K sizes. */ - void configure(const ITensor *b, ITensor *transformed_b, bool transpose_b, const CPUInfo &ci, const INEGEMMWrapperKernel::Params ¶ms); + void configure(const ITensor *b, ITensor *transformed_b, bool transpose_b, const CPUInfo &ci, const INEGEMMWrapperKernel::Params ¶ms) + { + const unsigned int multis = b->info()->tensor_shape().z(); + _Nsize = b->info()->tensor_shape().x(); + _Ksize = b->info()->tensor_shape().y(); + _b = b; + _transformed_b = transformed_b; + _transpose_b = transpose_b; + + _block_sizes = calculate_block_sizes(ci, params.M, params.N, params.K); + + auto_init_if_empty(*transformed_b->info(), b->info()->clone()->set_tensor_shape(TensorShape{ detail::get_B_pretransposed_array_size(_Nsize, _Ksize, _block_sizes, multis) })); + + Window window; + window.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_Nsize, _block_sizes.x_block), _block_sizes.x_block)); + window.set(Window::DimY, Window::Dimension(0, ceil_to_multiple(_Ksize, _block_sizes.k_block), _block_sizes.k_block)); + window.set(Window::DimZ, Window::Dimension(0, multis)); + + INEKernel::configure(window); + } // Inherited methods overridden: - void transform(const PrepareBWorkload &wl, const ThreadInfo &info) override; - void create_workloads(std::vector &workloads) override; - void run(const Window &window, const ThreadInfo &info) override; - BlockSizes block_sizes() const override; + void transform(const PrepareBWorkload &wl, const ThreadInfo &info) override + { + strategy strat(info.cpu_info); + strat.transforms.PrepareB(reinterpret_cast(_transformed_b->buffer() + wl._offset_transformed_b), + reinterpret_cast(_b->buffer() + wl._offset_b), + _b->info()->strides_in_bytes().y() / sizeof(typename strategy::operand_type), + wl._x0, wl._xmax, wl._k0, wl._kmax, _transpose_b); + } + void create_workloads(std::vector &workloads) override + { + detail::for_each_element_in_window(window(), _b, _transformed_b, _Nsize, _Ksize, [&workloads](PrepareBWorkload && wl) + { + workloads.push_back(std::move(wl)); + }); + } + void run(const Window &window, const ThreadInfo &info) override + { + ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(window, INEKernel::window()); + detail::for_each_element_in_window(window, _b, _transformed_b, _Nsize, _Ksize, [&](PrepareBWorkload && wl) + { + this->transform(wl, info); + }); + } + BlockSizes block_sizes() const override + { + return _block_sizes; + } private: const ITensor *_b diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h index b6831e3ca9..5d6cd02398 100644 --- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h +++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -25,8 +25,13 @@ #define __ARM_COMPUTE_NEGEMMINTERLEAVEDTRANSFORMAWRAPPER_H__ #include "arm_compute/core/CPP/CPPTypes.h" +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/ITensor.h" #include "arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h" +#include "arm_compute/core/Utils.h" +#include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" +#include "arm_compute/core/WindowIterator.h" namespace arm_compute { @@ -76,7 +81,7 @@ public: }; /** Type specialisations of @ref NEGEMMInterleavedTransformAWrapper */ -template +template class NEGEMMInterleavedTransformAWrapperTemplate : public NEGEMMInterleavedTransformAWrapper { public: @@ -88,11 +93,67 @@ public: * @param[in] block_walker Window representing the layout of the matrix's blocks * @param[in] params M, N, K sizes. */ - void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, const Window &block_walker, const INEGEMMWrapperKernel::Params ¶ms); + void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, const Window &block_walker, const INEGEMMWrapperKernel::Params ¶ms) + { + _a = a; + _transformed_a = transformed_a; + _transpose_a = transpose_a; + _Ksize = params.K; + _Msize = params.M; + _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension + } // Inherited methods overridden: - void transform(const TransformAWorkload &wl, const ThreadInfo &info, const Window &batch_window, const Coordinates &start_offset, const Coordinates &end_offset) override; - void create_workloads(std::vector &workloads) override; + void transform(const TransformAWorkload &wl, const ThreadInfo &info, const Window &batch_window, const Coordinates &start_offset, const Coordinates &end_offset) override + { + strategy strat(info.cpu_info); + TensorAccessor a(*_a); + TensorAccessor transformed_a(*_transformed_a); + + if(_a->info()->data_layout() == DataLayout::NHWC) + { + // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is + // the relevant multiple of the row stride. + const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _Msize; + a.set_stride(2, nhwc_batch_stride); + } + + unsigned int last_m = 0; + //TODO: Create a new iterate_1D( DimY); + int last_y = -1; + auto window_iterator = arm_compute::create_window_iterator(batch_window, start_offset, end_offset, [&](const Coordinates & id) + { + if(id.y() != last_y) + { + last_y = id.y(); + unsigned int batch = id.y(); + unsigned int first_m = id.x(); + + if(first_m >= last_m) + return; + + strat.transforms.PrepareA(transformed_a(0, first_m, batch), + a(0, 0, batch, wl._multi), + a.stride(1), first_m, last_m, wl._k0, wl._kmax, _transpose_a); + } + }); + auto on_new_row_size = [&](unsigned int start, unsigned int end) + { + last_m = std::min(end, _Msize); + }; + window_iterator.iterate_2D(on_new_row_size); + } + void create_workloads(std::vector &workloads) override + { + execute_window_loop(_k_multi_window, [&](const Coordinates & id) + { + const unsigned int k0 = id.x(); + const unsigned int multi = id.y(); + const unsigned int kmax = std::min(k0 + _k_multi_window.x().step(), _Ksize); + + workloads.push_back(TransformAWorkload(k0, kmax, multi)); + }); + } private: const ITensor *_a diff --git a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp index 162cbc5c46..26c1f3df89 100644 --- a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp +++ b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,6 +24,7 @@ #pragma once #include +#include #include "arm_gemm_local.hpp" #include "gemm_common.hpp" @@ -37,45 +38,57 @@ enum class GemmMethod GEMV_PRETRANSPOSED, GEMV_NATIVE_TRANSPOSED, GEMM_NATIVE, - GEMM_INTERLEAVED, - GEMM_INTERLEAVED_FP16, - GEMM_INTERLEAVED_DOT + GEMM_HYBRID, + GEMM_INTERLEAVED +}; + + +struct KernelDescription +{ + GemmMethod method = GemmMethod::DEFAULT; + std::string name = ""; + + KernelDescription(GemmMethod m, std::string n) : method(m), name(n) { } + KernelDescription() { } }; struct GemmConfig { - GemmMethod method = GemmMethod::DEFAULT; + GemmMethod method = GemmMethod::DEFAULT; + std::string filter = ""; unsigned int inner_block_size = 0; unsigned int outer_block_size = 0; GemmConfig(GemmMethod method) : method(method) { } + GemmConfig() { } }; template struct GemmArgs { public: - const CPUInfo *_ci; - unsigned int _Msize; - unsigned int _Nsize; - unsigned int _Ksize; - unsigned int _nbatches; - unsigned int _nmulti; - bool _trA; - bool _trB; - T _alpha; - T _beta; - int _maxthreads; - bool _pretransposed_hint; + const CPUInfo *_ci; + unsigned int _Msize; + unsigned int _Nsize; + unsigned int _Ksize; + unsigned int _nbatches; + unsigned int _nmulti; + bool _trA; + bool _trB; + T _alpha; + T _beta; + int _maxthreads; + bool _pretransposed_hint; + const GemmConfig *_cfg; GemmArgs(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, const T alpha, const T beta, const int maxthreads, - const bool pretransposed_hint) : - _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), - _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads), - _pretransposed_hint(pretransposed_hint) + const bool pretransposed_hint, const GemmConfig *cfg=nullptr ) : + _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), + _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads), + _pretransposed_hint(pretransposed_hint), _cfg(cfg) { } }; @@ -90,7 +103,7 @@ using UniqueGemmCommon = std::unique_ptr >; * provided parameters be provided using the supplied method? */ template -bool method_is_compatible(GemmMethod method, GemmArgs &args); +bool method_is_compatible(GemmMethod method, const GemmArgs &args); template bool method_is_compatible(GemmMethod method, const CPUInfo &ci, @@ -107,14 +120,14 @@ bool method_is_compatible(GemmMethod method, const CPUInfo &ci, /* get_gemm_method(): Given the templated types and provided parameters, * which is the preferred method to implement this GEMM? */ template -GemmMethod get_gemm_method(GemmArgs &args); +KernelDescription get_gemm_method(const GemmArgs &args); template -GemmMethod get_gemm_method(const CPUInfo &ci, - const unsigned int M, const unsigned int N, const unsigned int K, - const unsigned int nbatches, const unsigned int nmulti, - const bool trA, const bool trB, const Tret alpha, const Tret beta, - const int maxthreads, const bool pretransposed_hint) +KernelDescription get_gemm_method(const CPUInfo &ci, + const unsigned int M, const unsigned int N, const unsigned int K, + const unsigned int nbatches, const unsigned int nmulti, + const bool trA, const bool trB, const Tret alpha, const Tret beta, + const int maxthreads, const bool pretransposed_hint) { GemmArgs args(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint); @@ -122,7 +135,7 @@ GemmMethod get_gemm_method(const CPUInfo &ci, } template -UniqueGemmCommon gemm(GemmArgs &args, GemmConfig *cfg); +UniqueGemmCommon gemm(const GemmArgs &args); /** Request an object to process a GEMM. * @@ -146,10 +159,25 @@ UniqueGemmCommon gemm(const CPUInfo &ci, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, const Tret alpha, const Tret beta, const int maxthreads, const bool pretransposed_hint, GemmConfig *cfg=nullptr) +{ + GemmArgs args(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint, cfg); + + return gemm(args); +} + +template +std::vector get_compatible_kernels(const GemmArgs &args); + +template +std::vector get_compatible_kernels(const CPUInfo &ci, + const unsigned int M, const unsigned int N, const unsigned int K, + const unsigned int nbatches, const unsigned int nmulti, + const bool trA, const bool trB, const Tret alpha, const Tret beta, + const int maxthreads, const bool pretransposed_hint) { GemmArgs args(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint); - return gemm(args, cfg); + return get_compatible_kernels(args); } } // namespace arm_gemm diff --git a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp index b43d6eaca6..7b4f0149e3 100644 --- a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp +++ b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -88,11 +88,11 @@ public: * This has an empty default implementation, as GEMMs which don't care * about thread count can safely ignore this. */ - virtual void set_nthreads(int nthreads) { }; + virtual void set_nthreads(int) { }; /* Actually do the work. Provide a threadid to index any per-thread * buffers, and a start/end range to indicate which work to do. */ - virtual void execute(unsigned int start, unsigned int end, int threadid) = 0; + virtual void execute(unsigned int, unsigned int, int) = 0; /*** Working space interface (optional) ***/ /* Total number of bytes of temporary working space needed. If zero, it's not necessary to call set_working_space(). */ @@ -108,9 +108,10 @@ public: /* Total number of bytes of space needed for pretransposed arrays. */ virtual size_t get_B_pretransposed_array_size() const { return 0; } /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */ - virtual void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) { }; + /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */ + virtual void pretranspose_B_array(void *, const To *, const int, const int) { }; /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */ - virtual void set_pretransposed_B_data(void *buffer) { } + virtual void set_pretransposed_B_data(void *) { } // Destructor virtual ~GemmCommon() { } diff --git a/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h b/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h index 26236ffb35..3ccfbc512b 100644 --- a/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h +++ b/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -26,6 +26,9 @@ #include "arm_compute/core/NEON/kernels/assembly/Helpers.h" #include "arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h" +#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h" +#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h" +#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h" #include "arm_compute/runtime/IFunction.h" #include "arm_compute/runtime/IMemoryManager.h" #include "arm_compute/runtime/IScheduler.h" @@ -36,13 +39,8 @@ namespace arm_compute { +// Forward declarations class ITensor; -class NEGEMMInterleavedPrepareBWrapperKernel; -class PrepareBWorkload; -class TransformAWorkload; -class MatrixMultiplyWorkload; -class NEGEMMInterleavedTransformAWrapper; -class NEGEMMInterleavedMatrixMultiplyWrapper; /** Buffer manager used when reshaping B on the fly * @@ -97,6 +95,7 @@ class NEGEMMInterleavedWrapper : public IFunction { public: NEGEMMInterleavedWrapper(std::shared_ptr memory_manager = nullptr); + ~NEGEMMInterleavedWrapper() = default; NEGEMMInterleavedWrapper(const NEGEMMInterleavedWrapper &) = delete; NEGEMMInterleavedWrapper &operator=(const NEGEMMInterleavedWrapper &) = delete; @@ -111,9 +110,8 @@ public: * @param[in] alpha Scalar multiplier to apply to AB matrix product. * @param[in] beta Scalar multiplier to apply to input C matrix before adding product. * @param[in] pretranspose_b If true, pretranspose B once during the prepare() stage instead of on the fly every time. - * @param[in] use_dot (Optional) If the input's type is U8/S8/QASYMM8 then use the dot product flavour or the matrix multiply routine. (Must be supported by the hardware). */ - void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b, bool use_dot = false); + void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b); // Inherited methods overridden: void run() override; @@ -143,6 +141,5 @@ private: std::vector _workloads{}; std::string _tag{}; }; - } // namespace arm_compute #endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDWRAPPER_H__ */ -- cgit v1.2.1