From c0b6f76561580414f08633a804fc548ccad65659 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 2 Nov 2020 01:37:17 +0000 Subject: COMPMID-3776: Indirect GEMM Signed-off-by: Georgios Pinitas Change-Id: I51a1b0f098bc3a8c408c50c92221e4df3061e12c Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4343 Tested-by: Arm Jenkins Reviewed-by: Sang-Hoon Park Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- src/runtime/NEON/functions/NEConvolutionLayer.cpp | 61 ++-- src/runtime/NEON/functions/NEGEMM.cpp | 26 +- .../NEON/functions/NEGEMMAssemblyDispatch.cpp | 328 +++++++++++++++++---- src/runtime/NEON/functions/NEGEMMConv2d.cpp | 167 +++++++++++ .../NEGEMMLowpAssemblyMatrixMultiplyCore.cpp | 142 --------- .../functions/NEGEMMLowpMatrixMultiplyCore.cpp | 28 +- .../NEON/functions/NESimpleAssemblyFunction.cpp | 46 --- .../NEON/functions/NESimpleAssemblyFunction.h | 56 ---- 8 files changed, 514 insertions(+), 340 deletions(-) create mode 100644 src/runtime/NEON/functions/NEGEMMConv2d.cpp delete mode 100644 src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp delete mode 100644 src/runtime/NEON/functions/NESimpleAssemblyFunction.cpp delete mode 100644 src/runtime/NEON/functions/NESimpleAssemblyFunction.h (limited to 'src/runtime/NEON/functions') diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp index 901b1e880e..cc5f160787 100644 --- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp @@ -27,27 +27,12 @@ #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" #include "arm_compute/runtime/NEON/NEScheduler.h" -#include "src/core/NEON/kernels/NECol2ImKernel.h" -#include "src/core/NEON/kernels/NEConvertQuantizedSignednessKernel.h" -#include "src/core/NEON/kernels/NECopyKernel.h" -#include "src/core/NEON/kernels/NEDirectConvolutionLayerKernel.h" -#include "src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h" -#include "src/core/NEON/kernels/NEFFTDigitReverseKernel.h" -#include "src/core/NEON/kernels/NEFFTRadixStageKernel.h" -#include "src/core/NEON/kernels/NEFFTScaleKernel.h" -#include "src/core/NEON/kernels/NEFillBorderKernel.h" -#include "src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h" -#include "src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h" -#include "src/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.h" -#include "src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h" -#include "src/core/NEON/kernels/NEGEMMLowpReductionKernel.h" -#include "src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.h" -#include "src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h" -#include "src/core/NEON/kernels/NEGEMMTranspose1xWKernel.h" -#include "src/core/NEON/kernels/NEIm2ColKernel.h" -#include "src/core/NEON/kernels/NEPadLayerKernel.h" -#include "src/core/NEON/kernels/NEReductionOperationKernel.h" -#include "src/core/NEON/kernels/NEWeightsReshapeKernel.h" +#include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h" +#include "arm_compute/runtime/NEON/functions/NEFFTConvolutionLayer.h" +#include "arm_compute/runtime/NEON/functions/NEGEMMConv2d.h" +#include "arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h" +#include "arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h" + #include "support/MemorySupport.h" #include @@ -71,6 +56,7 @@ void NEConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ARM_COMPUTE_ERROR_THROW_ON(NEConvolutionLayer::validate(input->info(), weights->info(), ((biases != nullptr) ? biases->info() : nullptr), output->info(), conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups)); + const Conv2dInfo info(conv_info, dilation, act_info, enable_fast_math, num_groups); switch(NEConvolutionLayer::get_convolution_method(input->info(), weights->info(), output->info(), conv_info, weights_info, dilation, act_info, enable_fast_math)) { case ConvolutionMethod::WINOGRAD: @@ -87,6 +73,13 @@ void NEConvolutionLayer::configure(ITensor *input, const ITensor *weights, const _function = std::move(f); break; } + case ConvolutionMethod::GEMM_CONV2D: + { + auto f = arm_compute::support::cpp14::make_unique(_memory_manager); + f->configure(input, weights, biases, output, info); + _function = std::move(f); + break; + } case ConvolutionMethod::DIRECT: { auto f = arm_compute::support::cpp14::make_unique(_memory_manager); @@ -112,22 +105,22 @@ Status NEConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo { ARM_COMPUTE_RETURN_ERROR_ON_MSG((num_groups != 1), "Grouping (num_groups != 1) is not supported on NEON"); + const Conv2dInfo info(conv_info, dilation, act_info, enable_fast_math, num_groups); switch(NEConvolutionLayer::get_convolution_method(input, weights, output, conv_info, weights_info, dilation, act_info, enable_fast_math)) { case ConvolutionMethod::WINOGRAD: - //Validate Winograd ARM_COMPUTE_RETURN_ON_ERROR(NEWinogradConvolutionLayer::validate(input, weights, biases, output, conv_info, act_info, enable_fast_math)); break; case ConvolutionMethod::GEMM: - //Validate Gemm-based Convolution ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, weights_info, dilation, act_info)); break; + case ConvolutionMethod::GEMM_CONV2D: + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMConv2d::validate(input, weights, biases, output, info)); + break; case ConvolutionMethod::DIRECT: - //Validate Direct Convolution ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayer::validate(input, weights, biases, output, conv_info, act_info)); break; case ConvolutionMethod::FFT: - // Validate FFT-based convolution layer ARM_COMPUTE_RETURN_ON_ERROR(NEFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)); break; default: @@ -149,6 +142,8 @@ ConvolutionMethod NEConvolutionLayer::get_convolution_method(const ITensorInfo * const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT); const size_t idx_c = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL); + const Conv2dInfo info(conv_info, dilation, act_info, enable_fast_math, 1); + /* Input spatial dims, kernel size, IFM/OFM, conv info*/ using ConvolutionConfiguration = std::tuple; using ConfigurationMethod = std::pair; @@ -235,7 +230,21 @@ ConvolutionMethod NEConvolutionLayer::get_convolution_method(const ITensorInfo * } } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - return bool(NEWinogradConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math)) ? ConvolutionMethod::WINOGRAD : ConvolutionMethod::GEMM; + // For 1x1 convolutions run the default GEMM + if(weights->dimension(idx_w) == 1 && weights->dimension(idx_h) == 1) + { + return ConvolutionMethod::GEMM; + } + + if(bool(NEWinogradConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math))) + { + return ConvolutionMethod::WINOGRAD; + } + if(bool(NEGEMMConv2d::validate(input, weights, nullptr, output, info))) + { + return ConvolutionMethod::GEMM_CONV2D; + } + return ConvolutionMethod::GEMM; } } diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp index 0215098792..9f52e458d2 100644 --- a/src/runtime/NEON/functions/NEGEMM.cpp +++ b/src/runtime/NEON/functions/NEGEMM.cpp @@ -47,7 +47,19 @@ using namespace arm_compute::misc::shape_calculator; namespace arm_compute { -NEGEMM::~NEGEMM() = default; +namespace +{ +AsmGemmInfo init_assembly_metadata(const GEMMInfo &info) +{ + AsmGemmInfo asm_info; + asm_info.method = AsmConvMethod::Im2Col; + asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d(); + asm_info.depth_output_gemm3d = info.depth_output_gemm3d(); + asm_info.activation_info = info.activation_info(); + + return asm_info; +} +} // namespace NEGEMM::NEGEMM(std::shared_ptr memory_manager, IWeightsManager *weights_manager) : _memory_group(memory_manager), _weights_manager(weights_manager), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(memory_manager, weights_manager), _ma_kernel(), @@ -56,12 +68,15 @@ NEGEMM::NEGEMM(std::shared_ptr memory_manager, IWeightsManager * { } +NEGEMM::~NEGEMM() = default; + void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info) { ARM_COMPUTE_ERROR_THROW_ON(NEGEMM::validate(a->info(), b->info(), (c != nullptr) ? c->info() : nullptr, d->info(), alpha, beta, gemm_info)); - const bool is_c_bias = gemm_info.reshape_b_only_on_first_run(); - bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), (is_c_bias && c != nullptr) ? c->info() : nullptr, d->info(), gemm_info)); + const AsmGemmInfo asm_info = init_assembly_metadata(gemm_info); + const bool is_c_bias = gemm_info.reshape_b_only_on_first_run(); + bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), (is_c_bias && c != nullptr) ? c->info() : nullptr, d->info(), asm_info)); // Check if we need to reshape the matrix B only on the first run _is_prepared = false; @@ -76,7 +91,7 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe if(run_optimised) { const ITensor *c_to_use = is_c_bias ? c : nullptr; - _asm_glue.configure(a, b, c_to_use, d, gemm_info); + _asm_glue.configure(a, b, c_to_use, d, asm_info); ARM_COMPUTE_ERROR_ON(!_asm_glue.is_configured()); // Scale product by alpha @@ -221,7 +236,8 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso } // Check if we need to run the optimized assembly kernel - const bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, output, gemm_info)); + AsmGemmInfo asm_info = init_assembly_metadata(gemm_info); + const bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, output, asm_info)); if(!run_optimised) { diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index 5b0848398d..400fa64438 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -25,18 +25,70 @@ #include "arm_compute/runtime/NEON/NEScheduler.h" #include "src/core/CPP/Validate.h" -#include "src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h" #include "src/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h" #include "src/core/NEON/kernels/assembly/arm_gemm.hpp" #include "support/MemorySupport.h" #include +#include namespace arm_compute { namespace { +struct free_delete +{ + void operator()(void *x) + { + free(x); + } +}; + +struct Params +{ + unsigned int M; + unsigned int N; + unsigned int K; + unsigned int batches; + unsigned int multis; + unsigned int sections; + bool indirect; +}; + +Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *d, const AsmGemmInfo &info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); + + Params p; + p.K = a->info()->tensor_shape().x(); + p.N = d->info()->tensor_shape().x(); + p.multis = 1; + p.indirect = false; + p.sections = 1; + + if(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect) + { + p.indirect = true; + p.sections = b->info()->tensor_shape()[2] * b->info()->tensor_shape()[3]; + } + else + { + p.M = d->info()->tensor_shape().y(); + p.multis = b->info()->tensor_shape().z(); + p.batches = d->info()->tensor_shape().total_size_upper(2) / p.multis; //COMPMID-1423: Agree on and document the layout of gemm inputs/outputs + } + + // Update M in case of GEMM3D for output + if(info.depth_output_gemm3d != 0) + { + p.M = d->info()->tensor_shape().y() * d->info()->tensor_shape().z(); + p.batches = d->info()->tensor_shape().total_size_upper(3) / p.multis; + } + + return p; +} + arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act) { arm_gemm::Activation gemm_act; @@ -69,6 +121,29 @@ arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act) return gemm_act; } +IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataType data_type) +{ + // Schedule assembly kernel + const int granule_threshold = 200; + IScheduler::Hints scheduling_hint = IScheduler::Hints(Window::DimX); + if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && data_type == DataType::F32) + { + scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold); + } + else if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && (data_type == DataType::F32 || data_type == DataType::F16 || data_type == DataType::U8 || data_type == DataType::S8)) + { + //GEMM_INTERLEAVED supports 2D parallelism, IScheduler::split_dimensions_all signals to parallelise over all window dimensions + scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold); + } + else if(method == arm_gemm::GemmMethod::QUANTIZE_WRAPPER_2D && (data_type == DataType::QASYMM8 || data_type == DataType::QASYMM8_SIGNED)) + { + //special case for QASYMM8 to support 2D parallelism, scheduler here may be tweaked differently compared to FP32 case + scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold); + } + + return scheduling_hint; +} + template class FallbackTransform : public ITransformWeights { @@ -165,7 +240,7 @@ public: * @param[in] os Output stage meta-data. */ void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, - arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, + arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info, MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {}); /** Set requantization shifts to be used @@ -198,6 +273,16 @@ private: * @param[in] alignment Workspace memory alignment. */ void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment); + /** Configure the indirect buffer + * + * @param[in] a Input tensor containing the Matrix A. + * @param[in] b Input tensor containing the Matrix B. + * @param[out] d Output tensor to store the result of matrix multiplication. + * @param[in] info GEMM meta-data + */ + void configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info); + /** Prepare the indirect buffer */ + void prepare_indirect_buffer(); /** Assembly Gemm kernel */ std::shared_ptr> _gemm_kernel_asm{ nullptr }; @@ -226,7 +311,7 @@ private: /** Prepared flag */ bool _is_prepared{ false }; /** GEMM meta-data */ - GEMMInfo _gemm_info{}; + AsmGemmInfo _gemm_info{}; /** Weights manager */ IWeightsManager *_weights_manager{ nullptr }; /** Weights transform object */ @@ -239,11 +324,16 @@ private: std::vector left_shifts{}; /** Per channel quantization multipliers */ std::vector _multipliers{}; + /** Indirect buffer */ + std::unique_ptr _indirect_arg{}; + std::unique_ptr _indirect_buf{}; + std::vector _indirect_pad{}; + arm_gemm::ConvolutionParameters _cp{}; }; template -std::tuple Fallback::set_requantize_data(const std::vector &shifts, - const std::vector &multipliers) +std::tuple +Fallback::set_requantize_data(const std::vector &shifts, const std::vector &multipliers) { _multipliers = multipliers; _shifts = shifts; @@ -260,9 +350,123 @@ std::tuple Fallback +void Fallback::prepare_indirect_buffer() +{ + const TypeInput *A_ptr = reinterpret_cast(_a->buffer()); + const int multis = 1; + const int batches = _a->info()->tensor_shape().total_size_upper(3); + const size_t stride_A = _a->info()->strides_in_bytes().y() / sizeof(TypeInput); + const size_t batch_stride_A = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput); + const size_t multi_stride_A = _a->info()->strides_in_bytes()[4] / sizeof(TypeInput); + + const size_t output_hw = _cp.output_height * _cp.output_width; + const int batch_size = _cp.kernel_height * _cp.kernel_width * output_hw * sizeof(TypeInput); + const size_t batch_stride = batch_size / sizeof(TypeInput); + const int multi_size = batch_size * batches; + const size_t multi_stride = multi_size / sizeof(TypeInput); + + for(int64_t m = 0; m < multis; m++) + { + for(int64_t b = 0; b < batches; b++) + { + for(int64_t output_y = 0; output_y < _cp.output_height; output_y++) + { + for(int64_t output_x = 0; output_x < _cp.output_width; output_x++) + { + int64_t output_xy = (output_y * _cp.output_width) + output_x; + + for(int64_t kernel_y = 0; kernel_y < _cp.kernel_height; kernel_y++) + { + for(int64_t kernel_x = 0; kernel_x < _cp.kernel_width; kernel_x++) + { + int64_t input_x = (output_x * _cp.output_stride_w) + kernel_x - _cp.padding_left; + int64_t input_y = (output_y * _cp.output_stride_h) + kernel_y - _cp.padding_top; + int64_t kernel_xy = (kernel_y * _cp.kernel_width) + kernel_x; + int64_t input_xy = (input_y * _cp.input_width) + input_x; + + if(input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height) + { + _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = _indirect_pad.data(); + } + else + { + _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = + A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A); + } + } + } + } + } + } + } +} + +template +void Fallback::configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info) +{ + ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)); + + float zeropad = 0.f; + if(is_data_type_quantized(a->data_type())) + { + zeropad = a->quantization_info().uniform().offset; + } + + const int64_t input_width = static_cast(a->tensor_shape()[1]); + const int64_t input_height = static_cast(a->tensor_shape()[2]); + const int64_t input_channels = static_cast(a->tensor_shape()[0]); + const int64_t kernel_width = static_cast(b->tensor_shape()[2]); + const int64_t kernel_height = static_cast(b->tensor_shape()[3]); + const int64_t output_width = static_cast(d->tensor_shape()[1]); + const int64_t output_height = static_cast(d->tensor_shape()[2]); + + _cp = { input_width, input_height, input_channels, kernel_width, kernel_height, output_width, output_height, + info.ps_info.stride().first, info.ps_info.stride().second, info.padding_top, info.padding_left, zeropad + }; + + if(info.method == AsmConvMethod::Conv) + { + _gemm_kernel_asm->set_convolution_parameters(_cp); + } + + if(info.method == AsmConvMethod::Indirect) + { + const unsigned int multis = 1; + const unsigned int batches = a->tensor_shape().total_size_upper(3); + const unsigned int kernel_hw = _cp.kernel_width * _cp.kernel_height; + const unsigned int output_hw = _cp.output_width * _cp.output_height; + + using TypeInputPtr = TypeInput *; + const int batch_size = kernel_hw * output_hw * sizeof(TypeInputPtr); + const size_t batch_stride = batch_size / sizeof(TypeInputPtr); + const int multi_size = batch_size * batches; + const size_t multi_stride = multi_size / sizeof(TypeInputPtr); + + _indirect_buf = std::unique_ptr(reinterpret_cast(malloc(multi_size * multis))); + _indirect_arg = std::unique_ptr(reinterpret_cast(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches))); + _indirect_pad = std::vector(_cp.input_channels, zeropad); + + // Set indirect argument + int64_t pos = 0; + for(int64_t m = 0; m < multis; m++) + { + for(int64_t b = 0; b < batches; b++) + { + for(int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++) + { + (_indirect_arg.get())[pos++] = _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw; + } + } + } + + _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get()); + } +} + template void Fallback::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, - arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, + arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info, MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os) { arm_gemm::GemmConfig gemm_cfg; @@ -325,6 +529,12 @@ void Fallback::configure(const ITensor *a, c static_cast(_pretranspose)->allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment); } } + + // Handle indirect GEMM convolution + if(gemm_info.method == AsmConvMethod::Conv || gemm_info.method == AsmConvMethod::Indirect) + { + configure_indirect(a->info(), b->info(), d->info(), gemm_info); + } } template @@ -365,6 +575,11 @@ void Fallback::prepare() } } + if(_gemm_info.method == AsmConvMethod::Indirect) + { + prepare_indirect_buffer(); + } + _is_prepared = true; } } @@ -387,23 +602,23 @@ bool Fallback::is_configured() const template void Fallback::run() { - const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput); + int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput); int ldb = 0; const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput); - const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d() != 0 ? 3 : 2; + const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2; const size_t a_multi_idx = a_batch_idx + 1; - const size_t d_batch_idx = _gemm_info.depth_output_gemm3d() != 0 ? 3 : 2; + const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2; const size_t d_multi_idx = d_batch_idx + 1; - const int batch_stride_a = _a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput); + int batch_stride_a = _a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput); const int batch_stride_d = _d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput); - const int multi_stride_a = _a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput); + int multi_stride_a = _a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput); int multi_stride_b = 0; const int multi_stride_d = _d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput); - const auto in0_ptr = reinterpret_cast(_a->buffer() + _a->info()->offset_first_element_in_bytes()); + auto in0_ptr = reinterpret_cast(_a->buffer() + _a->info()->offset_first_element_in_bytes()); const TypeInput *in1_ptr = nullptr; auto out_ptr = reinterpret_cast(_d->buffer() + _d->info()->offset_first_element_in_bytes()); @@ -415,25 +630,7 @@ void Fallback::run() in1_ptr = reinterpret_cast(_b->buffer() + _b->info()->offset_first_element_in_bytes()); } - IScheduler::Hints scheduling_hint = IScheduler::Hints(Window::DimX); - if(_kernel_info.method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && _d->info()->data_type() == DataType::F32) - { - const int granule_threshold = 200; - scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold); - } - else if(_kernel_info.method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && (_d->info()->data_type() == DataType::F32 || _d->info()->data_type() == DataType::F16 - || _d->info()->data_type() == DataType::U8 || _d->info()->data_type() == DataType::S8)) - { - //GEMM_INTERLEAVED supports 2D parallelism, IScheduler::split_dimensions_all signals to parallelise over all window dimensions - const int granule_threshold = 200; - scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold); - } - else if(_kernel_info.method == arm_gemm::GemmMethod::QUANTIZE_WRAPPER_2D && (_d->info()->data_type() == DataType::QASYMM8 || _d->info()->data_type() == DataType::QASYMM8_SIGNED)) - { - //special case for QASYMM8 to support 2D parallelism, scheduler here may be tweaked differently compared to FP32 case - const int granule_threshold = 200; - scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold); - } + const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, _d->info()->data_type()); // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads if(_workspace.buffer() != nullptr) @@ -458,57 +655,67 @@ void Fallback::run() // Prepare assembly kernel prepare(); - TypeOutput *bias = nullptr; // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C. + TypeOutput *bias = nullptr; if(_c && _c->info()->data_type() != DataType::S32) { bias = reinterpret_cast(_c->buffer() + _c->info()->offset_first_element_in_bytes()); } + + if(_gemm_info.method == AsmConvMethod::Indirect) + { + in0_ptr = nullptr; + lda = 0; + batch_stride_a = 0; + multi_stride_a = 0; + } + // Set gemm parameters _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d, bias, 0); - // Schedule assembly kernel + // Schedule NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint); } template void create_arm_gemm(std::unique_ptr &arm_gemm, MemoryGroup &memory_group, - const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info, + const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const AsmGemmInfo &info, IWeightsManager *weights_manager) { - INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info); - const CPUInfo &ci = NEScheduler::get().cpu_info(); - unsigned int num_threads = NEScheduler::get().num_threads(); + Params p = extract_parameters(a, b, d, info); + const CPUInfo &ci = NEScheduler::get().cpu_info(); + unsigned int num_threads = NEScheduler::get().num_threads(); - arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, activation, num_threads); + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads); // Create arm_gemm fallback auto fallback = support::cpp14::make_unique>(); - fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager); + fallback->configure(a, b, c, d, args, info, memory_group, weights_manager); arm_gemm = std::move(fallback); } template void create_arm_gemm_quant(std::unique_ptr &arm_gemm, MemoryGroup &memory_group, - const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info, + const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const AsmGemmInfo &info, IWeightsManager *weights_manager) { ARM_COMPUTE_UNUSED(activation); - INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info); - const CPUInfo &ci = NEScheduler::get().cpu_info(); - unsigned int num_threads = NEScheduler::get().num_threads(); + Params p = extract_parameters(a, b, d, info); + const CPUInfo &ci = NEScheduler::get().cpu_info(); + unsigned int num_threads = NEScheduler::get().num_threads(); - arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, activation, num_threads); + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads); // Create arm_gemm fallback auto fallback = support::cpp14::make_unique>(); // Configure requantization info - const int32_t a_offset = -a->info()->quantization_info().uniform().offset; - const int32_t b_offset = -b->info()->quantization_info().uniform().offset; - const GEMMLowpOutputStageInfo os_info = gemm_info.gemmlowp_output_stage(); + const int32_t negation = info.negated_offsets ? 1 : -1; + const int32_t a_offset = -a->info()->quantization_info().uniform().offset * negation; + const int32_t b_offset = -b->info()->quantization_info().uniform().offset * negation; + const GEMMLowpOutputStageInfo os_info = info.output_stage; arm_gemm::Requantize32 gemm_requant_info{}; if(os_info.gemmlowp_shifts.size() > 1) @@ -530,7 +737,7 @@ void create_arm_gemm_quant(std::unique_ptr &a } // Configure fallback - fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info); + fallback->configure(a, b, c, d, args, info, memory_group, weights_manager, gemm_requant_info); arm_gemm = std::move(fallback); } @@ -541,14 +748,13 @@ NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr m { } -Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const GEMMInfo &gemm_info) +Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info) { - ARM_COMPUTE_UNUSED(c); + ARM_COMPUTE_UNUSED(c, info); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a); - ARM_COMPUTE_RETURN_ERROR_ON(!gemm_info.pretranpose_B()); #ifndef __aarch64__ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64"); #endif /* __aarch64__ */ @@ -579,13 +785,13 @@ bool NEGEMMAssemblyDispatch::is_activation_supported(const ActivationLayerInfo & return act.type != arm_gemm::Activation::Type::None; } -void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const GEMMInfo &gemm_info) +void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const AsmGemmInfo &info) { ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); - arm_gemm::Activation act = map_to_arm_gemm_activation(gemm_info.activation_info()); + arm_gemm::Activation act = map_to_arm_gemm_activation(info.activation_info); //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured() - if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), gemm_info)) + if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), info)) { return; } @@ -593,40 +799,40 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const switch(a->info()->data_type()) { case DataType::F32: - create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager); break; #ifdef __aarch64__ case DataType::U8: case DataType::QASYMM8: if(d->info()->data_type() == DataType::S32) { - create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager); } else { - create_arm_gemm_quant(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + create_arm_gemm_quant(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager); } break; case DataType::S8: case DataType::QASYMM8_SIGNED: if(d->info()->data_type() == DataType::S32) { - create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager); } else { - create_arm_gemm_quant(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + create_arm_gemm_quant(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager); } break; #endif /* __aarch64__ */ #if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) case DataType::BFLOAT16: - create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager); break; #endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: - create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager); break; #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ default: diff --git a/src/runtime/NEON/functions/NEGEMMConv2d.cpp b/src/runtime/NEON/functions/NEGEMMConv2d.cpp new file mode 100644 index 0000000000..642b084fb4 --- /dev/null +++ b/src/runtime/NEON/functions/NEGEMMConv2d.cpp @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2020 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "arm_compute/runtime/NEON/functions/NEGEMMConv2d.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/core/utils/quantization/AsymmHelpers.h" +#include "arm_compute/runtime/NEON/NEScheduler.h" +#include +namespace arm_compute +{ +namespace +{ +GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const ActivationLayerInfo &act) +{ + // Since we need negative offsets for computing convolution, we need to change QuantizationInfo() + // Extract and negate input and weights offset + const QuantizationInfo iqinfo = input->quantization_info(); + const QuantizationInfo wqinfo = weights->quantization_info(); + const QuantizationInfo oqinfo = (output->total_size() == 0) ? iqinfo : output->quantization_info(); + const UniformQuantizationInfo uoqinfo = oqinfo.uniform(); + const DataType data_type = input->data_type(); + // Merge activation with output stage + const std::set supported_acts = { ActivationLayerInfo::ActivationFunction::RELU, + ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, + ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU + }; + PixelValue type_min{}; + PixelValue type_max{}; + std::tie(type_min, type_max) = get_min_max(data_type); + int32_t min_activation = type_min.get(); + int32_t max_activation = type_max.get(); + if(supported_acts.count(act.activation()) != 0) + { + std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act, data_type, uoqinfo); + } + GEMMLowpOutputStageInfo os_info; + os_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT; + os_info.gemmlowp_offset = uoqinfo.offset; + os_info.gemmlowp_min_bound = min_activation; + os_info.gemmlowp_max_bound = max_activation; + os_info.is_quantized_per_channel = (weights->data_type() == DataType::QSYMM8_PER_CHANNEL); + quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, os_info); + return os_info; +} +AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect) +{ + AsmGemmInfo asm_info; + asm_info.method = is_indirect ? AsmConvMethod::Indirect : AsmConvMethod::Conv; + asm_info.ps_info = info.conv_info; + asm_info.activation_info = info.act_info; + asm_info.depth_output_gemm3d = true; + asm_info.reinterpret_input_as_3d = true; + asm_info.padding_top = info.conv_info.pad_top(); + asm_info.padding_left = info.conv_info.pad_left(); + asm_info.padding_value = 0.f; + asm_info.negated_offsets = false; + return asm_info; +} +} // namespace + +NEGEMMConv2d::NEGEMMConv2d(const std::shared_ptr &memory_manager) + : _gemm_asm_func(memory_manager), _activation_func(), _weights_permute_func(), _original_weights(nullptr), _permuted_weights(), _is_prepared(false), _run_activation(false) +{ +} +void NEGEMMConv2d::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const Conv2dInfo &info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); + ARM_COMPUTE_ERROR_THROW_ON(NEGEMMConv2d::validate(input->info(), + weights->info(), + biases != nullptr ? biases->info() : nullptr, + output->info(), + info)); + _original_weights = weights; + _weights_permute_func.configure(weights, &_permuted_weights, PermutationVector{ 3, 0, 1, 2 }); + + // Configure assembly dispatch + AsmGemmInfo asm_info = init_assembly_metadata(info, false); + if(is_data_type_quantized(input->info()->data_type())) + { + asm_info.output_stage = calculate_output_stage_metadata(input->info(), weights->info(), output->info(), info.act_info); + } + _gemm_asm_func.configure(input, &_permuted_weights, biases, output, asm_info); + + // Configure activation + if(info.act_info.enabled() && !_gemm_asm_func.is_activation_supported(info.act_info)) + { + _activation_func.configure(output, nullptr, info.act_info); + _run_activation = true; + } +} +Status NEGEMMConv2d::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const Conv2dInfo &info) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::BFLOAT16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::BFLOAT16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.num_groups > 1, "Grouping (num_groups != 1) is not supported on NEON"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_layout() != DataLayout::NHWC, "Data layout supported is NHWC"); + const DataType data_type = input->data_type(); + const TensorShape i_shape = input->tensor_shape(); + const TensorShape w_shape = weights->tensor_shape(); + ARM_COMPUTE_RETURN_ERROR_ON(w_shape[0] != i_shape[0]); + ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4); + // Validate biases + if(biases != nullptr) + { + if(is_data_type_quantized_asymmetric(data_type)) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32); + } + else if(data_type == DataType::BFLOAT16) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::F32); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases); + } + ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(3)); + ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); + } + + AsmGemmInfo asm_info = init_assembly_metadata(info, false); + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMAssemblyDispatch::validate(input, weights, biases, output, asm_info)); + return Status{}; +} +void NEGEMMConv2d::run() +{ + prepare(); + + _gemm_asm_func.run(); + if(_run_activation) + { + _activation_func.run(); + } +} +void NEGEMMConv2d::prepare() +{ + if(!_is_prepared) + { + _permuted_weights.allocator()->allocate(); + _weights_permute_func.run(); + _original_weights->mark_as_unused(); + _is_prepared = true; + } +} +} // namespace arm_compute diff --git a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp deleted file mode 100644 index 09637dd2d6..0000000000 --- a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright (c) 2017-2020 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#include "arm_compute/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.h" - -#include "arm_compute/core/Error.h" -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/ITensor.h" -#include "arm_compute/core/TensorInfo.h" -#include "arm_compute/core/Types.h" -#include "arm_compute/core/Validate.h" -#include "arm_compute/runtime/NEON/NEScheduler.h" -#include "arm_compute/runtime/TensorAllocator.h" -#include "src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h" -#include "src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h" -#include "src/core/NEON/kernels/NEGEMMTranspose1xWKernel.h" -#include "support/MemorySupport.h" - -namespace arm_compute -{ -NEGEMMLowpAssemblyMatrixMultiplyCore::~NEGEMMLowpAssemblyMatrixMultiplyCore() = default; - -NEGEMMLowpAssemblyMatrixMultiplyCore::NEGEMMLowpAssemblyMatrixMultiplyCore(std::shared_ptr memory_manager) - : _memory_group(memory_manager), _asm_glue(memory_manager), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _tmp_a(), _tmp_b() -{ -} - -void NEGEMMLowpAssemblyMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *output) -{ - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::U8, DataType::S8); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32, DataType::S32); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(a, b); - ARM_COMPUTE_ERROR_ON_MSG((a)->info()->dimension(0) != (b)->info()->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B"); - ARM_COMPUTE_ERROR_ON_MSG((a)->info()->dimension(1) != (output)->info()->dimension(1), "The output matrix must have the same number of rows as the matrix A"); - ARM_COMPUTE_ERROR_ON_MSG((b)->info()->dimension(0) != (output)->info()->dimension(0), "The output matrix must have the same number of columns as the matrix B"); - - bool run_optimised = false; - switch(a->info()->data_type()) - { - case DataType::S8: - case DataType::QASYMM8: - case DataType::U8: - { - _asm_glue.configure(a, b, c, output, GEMMInfo(false, false, true)); - run_optimised = _asm_glue.is_configured(); - break; - } - default: - { - ARM_COMPUTE_ERROR("Datatype not supported"); - break; - } - } - if(!run_optimised) - { - // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ] - TensorShape shape_tmp_a = a->info()->tensor_shape(); - shape_tmp_a.set(0, a->info()->dimension(0) * 4); - shape_tmp_a.set(1, std::ceil(a->info()->dimension(1) / 4.f)); - - // The transpose1xW output matrix will have the following shape: [ b_height * 16, ceil(b_width / 16.0f) ] - TensorShape shape_tmp_b = b->info()->tensor_shape(); - shape_tmp_b.set(0, b->info()->dimension(1) * 16); - shape_tmp_b.set(1, std::ceil(b->info()->dimension(0) / 16.f)); - - TensorInfo info_a(shape_tmp_a, 1, a->info()->data_type()); - TensorInfo info_b(shape_tmp_b, 1, b->info()->data_type()); - _tmp_a.allocator()->init(info_a); - _tmp_b.allocator()->init(info_b); - _memory_group.manage(&_tmp_a); - _memory_group.manage(&_tmp_b); - - // Configure interleave kernel - { - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(a, &_tmp_a); - _mtx_a_reshape_kernel = std::move(k); - } - - // Configure transpose kernel - { - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(b, &_tmp_b); - _mtx_b_reshape_kernel = std::move(k); - } - - // Configure matrix multiply kernel - { - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(&_tmp_a, &_tmp_b, output); - _mm_kernel = std::move(k); - } - - // Allocate tensors - _tmp_a.allocator()->allocate(); - _tmp_b.allocator()->allocate(); - } -} - -void NEGEMMLowpAssemblyMatrixMultiplyCore::run() -{ - MemoryGroupResourceScope scope_mg(_memory_group); - if(_mtx_a_reshape_kernel) - { - NEScheduler::get().schedule(_mtx_a_reshape_kernel.get(), Window::DimY); - } - - if(_mtx_b_reshape_kernel) - { - NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY); - } - - if(_asm_glue.is_configured()) - { - _asm_glue.run(); - } - else - { - NEScheduler::get().schedule(_mm_kernel.get(), Window::DimY); - } -} -} // namespace arm_compute \ No newline at end of file diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp index 9050427b34..df8eaacf47 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp @@ -47,6 +47,21 @@ namespace arm_compute { +namespace +{ +AsmGemmInfo init_assembly_metadata(const GEMMInfo &info) +{ + AsmGemmInfo asm_info; + asm_info.method = AsmConvMethod::Im2Col; + asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d(); + asm_info.depth_output_gemm3d = info.depth_output_gemm3d(); + asm_info.activation_info = info.activation_info(); + asm_info.output_stage = info.gemmlowp_output_stage(); + + return asm_info; +} +} // namespace + using namespace arm_compute::misc::shape_calculator; NEGEMMLowpMatrixMultiplyCore::~NEGEMMLowpMatrixMultiplyCore() = default; @@ -120,6 +135,8 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, _mm_result_s32.allocator()->init(info_mm_result_s32); } + // Initialize assembly kernel meta-data + const AsmGemmInfo asm_info = init_assembly_metadata(gemm_info); #ifdef __aarch64__ switch(a->info()->data_type()) { @@ -130,12 +147,12 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, { if(is_data_type_quantized_asymmetric(a_to_use->info()->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) { - _asm_glue.configure(a_to_use, b, c, output, gemm_info); + _asm_glue.configure(a_to_use, b, c, output, asm_info); _fused_assembly_path = _asm_glue.is_configured(); } else { - _asm_glue.configure(a_to_use, b, nullptr, _fuse_output_stage ? &_mm_result_s32 : output, gemm_info); + _asm_glue.configure(a_to_use, b, nullptr, _fuse_output_stage ? &_mm_result_s32 : output, asm_info); } _assembly_path = _asm_glue.is_configured(); break; @@ -346,17 +363,20 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso matrix_a_info = &signed_a; } + // Initialize assembly kernel meta-data + const AsmGemmInfo asm_info = init_assembly_metadata(info); + // Check if we need to run the optimized assembly kernel bool run_optimised = false; bool run_optimised_requantized = false; if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) { - run_optimised = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, c, output, gemm_info)); + run_optimised = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, c, output, asm_info)); run_optimised_requantized = run_optimised; } else { - run_optimised = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, gemm_info)); + run_optimised = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, asm_info)); } if(run_optimised) diff --git a/src/runtime/NEON/functions/NESimpleAssemblyFunction.cpp b/src/runtime/NEON/functions/NESimpleAssemblyFunction.cpp deleted file mode 100644 index d165b2235c..0000000000 --- a/src/runtime/NEON/functions/NESimpleAssemblyFunction.cpp +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2018-2020 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#include "src/runtime/NEON/functions/NESimpleAssemblyFunction.h" - -#include "arm_compute/core/Validate.h" -#include "arm_compute/runtime/NEON/NEScheduler.h" - -using namespace arm_compute; - -NESimpleAssemblyFunction::NESimpleAssemblyFunction() // NOLINT - : _kernel() -{ -} - -void NESimpleAssemblyFunction::run() -{ - NEScheduler::get().schedule(_kernel.get(), Window::DimX); -} - -void NESimpleAssemblyFunction::configure(std::unique_ptr kernel) -{ - ARM_COMPUTE_ERROR_ON_NULLPTR(kernel.get()); - _kernel = std::move(kernel); - ARM_COMPUTE_ERROR_ON_WINDOW_DIMENSIONS_GTE(_kernel->window(), 1); -} diff --git a/src/runtime/NEON/functions/NESimpleAssemblyFunction.h b/src/runtime/NEON/functions/NESimpleAssemblyFunction.h deleted file mode 100644 index e9be54d35f..0000000000 --- a/src/runtime/NEON/functions/NESimpleAssemblyFunction.h +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright (c) 2018-2020 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#ifndef ARM_COMPUTE_NESIMPLEASSEMBLYFUNCTION_H -#define ARM_COMPUTE_NESIMPLEASSEMBLYFUNCTION_H - -#include "arm_compute/runtime/IFunction.h" -#include "src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h" - -#include - -namespace arm_compute -{ -/** Basic interface for functions which have a single NEON GEMM wrapper kernel to run */ -class NESimpleAssemblyFunction : public IFunction -{ -public: - /** Constructor */ - NESimpleAssemblyFunction(); - - /** Configure the function with the kernel to run - * - * @param[in] kernel GEMM Wrapper kernel configured and ready to run - * - * @note The kernel is expected to have a 1D window. The function will multi-thread this window across the X dimension. - */ - void configure(std::unique_ptr kernel); - - // Inherited methods overridden: - void run() override final; - -protected: - std::unique_ptr _kernel; /**< Kernel to run */ -}; -} //namespace arm_compute -#endif /*ARM_COMPUTE_NESIMPLEASSEMBLYFUNCTION_H */ -- cgit v1.2.1