From 2e5fd637205770ec5e11096e6e19b8efc67d544e Mon Sep 17 00:00:00 2001 From: SiCongLi Date: Mon, 2 Mar 2020 15:39:15 +0000 Subject: COMPMID-3098 Fuse Relu and Bounded Relu with FullyConnected NEON Change-Id: Id28062445590d6c06b35f7d7434eb38393ae94a7 Signed-off-by: SiCongLi Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2875 Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Gian Marco Iodice Tested-by: Arm Jenkins --- arm_compute/core/Types.h | 8 ++ .../runtime/NEON/functions/NEFullyConnectedLayer.h | 8 +- .../NEON/functions/NEGEMMLowpMatrixMultiplyCore.h | 14 ++- examples/graph_deepspeech_v0_4_1.cpp | 7 +- src/graph/mutators/NodeFusionMutator.cpp | 6 +- src/graph/mutators/SyntheticDataTypeMutator.cpp | 6 +- .../NEON/functions/NEFullyConnectedLayer.cpp | 140 +++++++++++++-------- src/runtime/NEON/functions/NEGEMM.cpp | 14 ++- .../functions/NEGEMMLowpMatrixMultiplyCore.cpp | 106 ++++++++-------- tests/datasets/FullyConnectedLayerDataset.h | 13 +- tests/validation/NEON/FullyConnectedLayer.cpp | 73 +++++++++-- 11 files changed, 258 insertions(+), 137 deletions(-) diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index b6409879bb..711b68f236 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -2140,6 +2140,14 @@ public: { return _activation_info; } + /** Set activation layer info + * + * @param[in] activation_info ActivationLayerInfo object to set + */ + void set_activation_info(const ActivationLayerInfo &activation_info) + { + _activation_info = activation_info; + } private: bool _is_a_reshaped; diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h index db09da45ee..b14650c0e9 100644 --- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h +++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h @@ -162,9 +162,9 @@ public: void prepare() override; private: - void configure_fc_fc(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output); - void configure_conv_fc(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output); - void configure_mm(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output); + void configure_fc_fc(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const ActivationLayerInfo &act); + void configure_conv_fc(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const ActivationLayerInfo &act); + void configure_mm(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const ActivationLayerInfo &act); MemoryGroup _memory_group; IWeightsManager *_weights_manager; @@ -182,7 +182,7 @@ private: bool _are_weights_converted; bool _are_weights_reshaped; bool _is_fc_after_conv; - bool _is_quantized; + bool _is_quantized_asymmetric; bool _is_prepared; }; } // namespace arm_compute diff --git a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h index 508159eb77..74dedcf4c5 100644 --- a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h +++ b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -28,9 +28,12 @@ #include "arm_compute/core/NEON/INEKernel.h" #include "arm_compute/core/NEON/kernels/NEConvertQuantizedSignednessKernel.h" #include "arm_compute/core/NEON/kernels/NEConvertQuantizedSignednessKernel.h" +#include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h" +#include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMLowpReductionKernel.h" +#include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h" #include "arm_compute/runtime/IFunction.h" #include "arm_compute/runtime/IMemoryManager.h" #include "arm_compute/runtime/MemoryGroup.h" @@ -60,7 +63,7 @@ class NEGEMMLowpMatrixMultiplyCore : public IFunction { public: /** Constructor */ - NEGEMMLowpMatrixMultiplyCore(std::shared_ptr memory_manager = nullptr); + NEGEMMLowpMatrixMultiplyCore(std::shared_ptr memory_manager = nullptr, IWeightsManager *weights_manager = nullptr); /** Prevent instances of this class from being copied (As this class contains pointers) */ NEGEMMLowpMatrixMultiplyCore(const NEGEMMLowpMatrixMultiplyCore &) = delete; /** Default move constructor */ @@ -109,10 +112,11 @@ public: private: MemoryGroup _memory_group; + IWeightsManager *_weights_manager; NEGEMMAssemblyDispatch _asm_glue; - std::unique_ptr _mm_kernel; - std::unique_ptr _mtx_a_reshape_kernel; - std::unique_ptr _mtx_b_reshape_kernel; + NEGEMMLowpMatrixMultiplyKernel _mm_kernel; + NEGEMMInterleave4x4Kernel _mtx_a_reshape_kernel; + NEGEMMTranspose1xWKernel _mtx_b_reshape_kernel; NEGEMMLowpMatrixAReductionKernel _mtx_a_reduction_kernel; NEGEMMLowpMatrixBReductionKernel _mtx_b_reduction_kernel; NEGEMMLowpOffsetContributionKernel _offset_contribution_kernel; diff --git a/examples/graph_deepspeech_v0_4_1.cpp b/examples/graph_deepspeech_v0_4_1.cpp index cc65bf309d..b655452391 100644 --- a/examples/graph_deepspeech_v0_4_1.cpp +++ b/examples/graph_deepspeech_v0_4_1.cpp @@ -208,9 +208,10 @@ public: // Finalize graph GraphConfig config; - config.num_threads = common_params.threads; - config.use_tuner = common_params.enable_tuner; - config.tuner_file = common_params.tuner_file; + config.num_threads = common_params.threads; + config.use_tuner = common_params.enable_tuner; + config.tuner_file = common_params.tuner_file; + config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); graph.finalize(common_params.target, config); diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp index 151a8bfa03..273e6ce746 100644 --- a/src/graph/mutators/NodeFusionMutator.cpp +++ b/src/graph/mutators/NodeFusionMutator.cpp @@ -301,10 +301,6 @@ void NodeFusionMutator::mutate(Graph &g) { return true; }; - auto cl_target_prec = [](INode & n) - { - return n.assigned_target() == Target::CL; - }; auto qs8_prec = [&g](INode & n) { ARM_COMPUTE_ERROR_ON(n.output(0) == nullptr); @@ -322,7 +318,7 @@ void NodeFusionMutator::mutate(Graph &g) detail::fuse_layer(g, empty_prec, detail::fuse_node_with_activation, supported_fused_activations); detail::fuse_layer(g, empty_prec, detail::fuse_node_with_activation, supported_fused_activations); detail::fuse_layer(g, qs8_prec, detail::fuse_node_with_activation, supported_fused_activations); - detail::fuse_layer(g, cl_target_prec, detail::fuse_node_with_activation, supported_fused_activations); + detail::fuse_layer(g, empty_prec, detail::fuse_node_with_activation, supported_fused_activations); detail::fuse_layer(g, empty_prec, detail::fuse_convolution_with_batch_normalization); detail::fuse_layer(g, empty_prec, detail::fuse_depthwise_convolution_with_batch_normalization); } diff --git a/src/graph/mutators/SyntheticDataTypeMutator.cpp b/src/graph/mutators/SyntheticDataTypeMutator.cpp index b318df956e..0a9f5058dd 100644 --- a/src/graph/mutators/SyntheticDataTypeMutator.cpp +++ b/src/graph/mutators/SyntheticDataTypeMutator.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2019-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -175,6 +175,10 @@ void convert_special_tensors(Graph &g) { tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 128); } + else if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0); + } return true; }; diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp index b5f406da8d..6e398ac1d1 100644 --- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp +++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp @@ -39,7 +39,68 @@ using namespace arm_compute::misc::shape_calculator; namespace { -Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output) +// Get min, max bound of a quantized assymetric output tensor, with the effect of fused activation +std::pair get_quantized_asymmetric_output_min_max(const QuantizationInfo &q_info, const ActivationLayerInfo &act_info, DataType data_type) +{ + PixelValue type_min{}; + PixelValue type_max{}; + std::tie(type_min, type_max) = get_min_max(data_type); + const UniformQuantizationInfo q_unif = q_info.uniform(); + + if(act_info.enabled()) + { + switch(act_info.activation()) + { + case ActivationLayerInfo::ActivationFunction::RELU: + type_min = PixelValue(q_unif.offset); + break; + case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU: + type_min = PixelValue(q_unif.offset); + type_max = PixelValue(act_info.a(), data_type, q_info); + break; + case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU: + type_min = PixelValue(act_info.b(), data_type, q_info); + type_max = PixelValue(act_info.a(), data_type, q_info); + break; + default: + ARM_COMPUTE_ERROR("Activation function not supported."); + break; + } + } + + return std::make_pair(type_min, type_max); +} + +Status get_gemmlowp_output_stage_info(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const ActivationLayerInfo &act, + GEMMLowpOutputStageInfo &gemmlowp_output_stage_info) +{ + const auto data_type = input->data_type(); + const QuantizationInfo oq_info = output->quantization_info(); + const UniformQuantizationInfo iq_unif = input->quantization_info().uniform(); + const UniformQuantizationInfo wq_unif = weights->quantization_info().uniform(); + const UniformQuantizationInfo oq_unif = oq_info.uniform(); + + float multiplier = (iq_unif.scale * wq_unif.scale) / oq_unif.scale; + int32_t output_multiplier; + int32_t output_shift; + + ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift)); + + PixelValue type_min{}; + PixelValue type_max{}; + std::tie(type_min, type_max) = get_quantized_asymmetric_output_min_max(oq_info, act, data_type); + + gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier; + gemmlowp_output_stage_info.gemmlowp_shift = output_shift; + gemmlowp_output_stage_info.gemmlowp_offset = oq_unif.offset; + gemmlowp_output_stage_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT; + gemmlowp_output_stage_info.gemmlowp_min_bound = type_min.get(); + gemmlowp_output_stage_info.gemmlowp_max_bound = type_max.get(); + + return Status{}; +} + +Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ActivationLayerInfo &act) { if(is_data_type_quantized_asymmetric(input->data_type())) { @@ -48,23 +109,8 @@ Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const I const QuantizationInfo input_quantization_info(input->quantization_info().uniform().scale, -input->quantization_info().uniform().offset); const QuantizationInfo weights_quantization_info(weights->quantization_info().uniform().scale, -weights->quantization_info().uniform().offset); - const UniformQuantizationInfo iq_info = input->quantization_info().uniform(); - const UniformQuantizationInfo wq_info = weights->quantization_info().uniform(); - const UniformQuantizationInfo oq_info = output->quantization_info().uniform(); - - float multiplier = (iq_info.scale * wq_info.scale) / oq_info.scale; - int32_t output_multiplier; - int32_t output_shift; - ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift)); - GEMMLowpOutputStageInfo gemmlowp_output_stage_info; - gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier; - gemmlowp_output_stage_info.gemmlowp_shift = output_shift; - gemmlowp_output_stage_info.gemmlowp_offset = oq_info.offset; - gemmlowp_output_stage_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT; - const auto min_max_bound = get_min_max(input->data_type()); - gemmlowp_output_stage_info.gemmlowp_min_bound = (std::get<0>(min_max_bound)).get(); - gemmlowp_output_stage_info.gemmlowp_max_bound = (std::get<1>(min_max_bound)).get(); + ARM_COMPUTE_RETURN_ON_ERROR(get_gemmlowp_output_stage_info(input, weights, output, act, gemmlowp_output_stage_info)); GEMMInfo gemm_info; gemm_info.set_gemmlowp_output_stage(gemmlowp_output_stage_info); @@ -99,14 +145,14 @@ Status NEFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, c NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr memory_manager, IWeightsManager *weights_manager) : _memory_group(std::move(memory_manager)), _weights_manager(weights_manager), _flatten_kernel(), _convert_weights(), _convert_weights_managed(), _reshape_weights_function(), - _reshape_weights_managed_function(), _mm_gemm(nullptr, weights_manager), _mm_gemmlowp(), _flatten_output(), _converted_weights_output(), _reshape_weights_output(), _original_weights(nullptr), - _are_weights_converted(true), _are_weights_reshaped(false), _is_fc_after_conv(false), _is_quantized(false), _is_prepared(false) + _reshape_weights_managed_function(), _mm_gemm(nullptr, weights_manager), _mm_gemmlowp(nullptr, weights_manager), _flatten_output(), _converted_weights_output(), _reshape_weights_output(), + _original_weights(nullptr), _are_weights_converted(true), _are_weights_reshaped(false), _is_fc_after_conv(false), _is_quantized_asymmetric(false), _is_prepared(false) { } -void NEFullyConnectedLayer::configure_mm(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output) +void NEFullyConnectedLayer::configure_mm(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const ActivationLayerInfo &act) { - if(_is_quantized) + if(_is_quantized_asymmetric) { // Since we need negative offsets for computing convolution, we need to change QuantizationInfo() // Extract and negate input and weights offset @@ -117,25 +163,13 @@ void NEFullyConnectedLayer::configure_mm(const ITensor *input, const ITensor *we weights->info()->set_quantization_info(QuantizationInfo(weights_quantization_info.uniform().scale, -weights_quantization_info.uniform().offset)); // Configure gemmlowp function and output stage for asymmetric quantized types - const UniformQuantizationInfo iq_info = input->info()->quantization_info().uniform(); - const UniformQuantizationInfo wq_info = weights->info()->quantization_info().uniform(); - const UniformQuantizationInfo oq_info = output->info()->quantization_info().uniform(); - - float multiplier = (iq_info.scale * wq_info.scale) / oq_info.scale; - int32_t output_multiplier; - int32_t output_shift; - quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift); - GEMMLowpOutputStageInfo gemmlowp_output_stage_info; - gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier; - gemmlowp_output_stage_info.gemmlowp_shift = output_shift; - gemmlowp_output_stage_info.gemmlowp_offset = oq_info.offset; - gemmlowp_output_stage_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT; - const auto min_max_bound = get_min_max(input->info()->data_type()); - gemmlowp_output_stage_info.gemmlowp_min_bound = (std::get<0>(min_max_bound)).get(); - gemmlowp_output_stage_info.gemmlowp_max_bound = (std::get<1>(min_max_bound)).get(); + const Status status = get_gemmlowp_output_stage_info(input->info(), weights->info(), output->info(), act, gemmlowp_output_stage_info); + ARM_COMPUTE_ERROR_ON(status.error_code() != ErrorCode::OK); + GEMMInfo gemm_info; gemm_info.set_gemmlowp_output_stage(gemmlowp_output_stage_info); + gemm_info.set_activation_info(act); _mm_gemmlowp.configure(input, weights, biases, output, gemm_info); // Revert back QuantizatioInfo as input and weights could be used in other fully connected layers @@ -145,11 +179,13 @@ void NEFullyConnectedLayer::configure_mm(const ITensor *input, const ITensor *we else { // Configure matrix multiply kernel - _mm_gemm.configure(input, weights, biases, output, 1.f, 1.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */)); + GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */); + gemm_info.set_activation_info(act); + _mm_gemm.configure(input, weights, biases, output, 1.f, 1.0f, gemm_info); } } -void NEFullyConnectedLayer::configure_conv_fc(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output) +void NEFullyConnectedLayer::configure_conv_fc(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const ActivationLayerInfo &act) { ARM_COMPUTE_ERROR_ON((weights->info()->dimension(1) != (input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2)))); @@ -164,18 +200,18 @@ void NEFullyConnectedLayer::configure_conv_fc(const ITensor *input, const ITenso _flatten_kernel.configure(input, &_flatten_output); // Configure matrix multiply kernel - configure_mm(&_flatten_output, weights, biases, output); + configure_mm(&_flatten_output, weights, biases, output, act); // Allocate the output tensor for flatten once all the configure methods have been called _flatten_output.allocator()->allocate(); } -void NEFullyConnectedLayer::configure_fc_fc(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output) +void NEFullyConnectedLayer::configure_fc_fc(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const ActivationLayerInfo &act) { ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != weights->info()->dimension(1)); // Configure matrix multiply kernel - configure_mm(input, weights, biases, output); + configure_mm(input, weights, biases, output, act); } void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, @@ -189,11 +225,11 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh output->info(), fc_info)); - _are_weights_converted = true; - _are_weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true; - _is_fc_after_conv = true; - _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type()); - _original_weights = weights; + _are_weights_converted = true; + _are_weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true; + _is_fc_after_conv = true; + _is_quantized_asymmetric = is_data_type_quantized_asymmetric(input->info()->data_type()); + _original_weights = weights; if(_weights_manager) { @@ -263,12 +299,12 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh if(_is_fc_after_conv) { // Fully Connected layer after a Convolution Layer without batches - configure_conv_fc(input, weights_to_use, biases, output); + configure_conv_fc(input, weights_to_use, biases, output, fc_info.activation_info); } else { // Fully Connected layer after a Fully Connected Layer without batches - configure_fc_fc(input, weights_to_use, biases, output); + configure_fc_fc(input, weights_to_use, biases, output, fc_info.activation_info); } _are_weights_reshaped = _are_weights_reshaped || fc_info.retain_internal_weights; @@ -345,7 +381,7 @@ Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) != weights_to_use->dimension(1)); } // Validate matrix multiply kernel - ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(input_to_use, weights_to_use, biases, output)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(input_to_use, weights_to_use, biases, output, fc_info.activation_info)); return Status{}; } @@ -363,7 +399,7 @@ void NEFullyConnectedLayer::run() } // Run matrix multiply - if(_is_quantized) + if(_is_quantized_asymmetric) { _mm_gemmlowp.run(); } @@ -436,7 +472,7 @@ void NEFullyConnectedLayer::prepare() release_unused(&_reshape_weights_output); // Prepare GEMM prepare and release unused weights - if(!_is_quantized) + if(!_is_quantized_asymmetric) { _mm_gemm.prepare(); } diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp index be964457fc..873145de12 100644 --- a/src/runtime/NEON/functions/NEGEMM.cpp +++ b/src/runtime/NEON/functions/NEGEMM.cpp @@ -336,25 +336,33 @@ void NEGEMM::prepare() { if(!_is_prepared) { + const bool original_b_managed_by_weights_manager = _weights_manager && _weights_manager->are_weights_managed(_original_b); if(_asm_glue.is_configured()) { - if(!_weights_manager || !_weights_manager->are_weights_managed(_original_b)) + if(!original_b_managed_by_weights_manager) { ARM_COMPUTE_ERROR_ON(!_original_b->is_used()); } _asm_glue.prepare(); + if(!original_b_managed_by_weights_manager) + { + _original_b->mark_as_unused(); + } } else if(_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication && !_asm_glue.is_configured()) { - if(!_weights_manager || !_weights_manager->are_weights_managed(_original_b)) + if(!original_b_managed_by_weights_manager) { ARM_COMPUTE_ERROR_ON(!_original_b->is_used()); } _tmp_b.allocator()->allocate(); NEScheduler::get().schedule(&_transpose_kernel, Window::DimY); - _original_b->mark_as_unused(); + if(!original_b_managed_by_weights_manager) + { + _original_b->mark_as_unused(); + } } _is_prepared = true; diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp index 3417c72735..a6ebcacf29 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp @@ -27,9 +27,6 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensor.h" #include "arm_compute/core/KernelDescriptors.h" -#include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h" -#include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h" -#include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" @@ -42,11 +39,12 @@ namespace arm_compute { using namespace arm_compute::misc::shape_calculator; -NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr memory_manager) - : _memory_group(memory_manager), _asm_glue(memory_manager), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(), - _offset_contribution_kernel(), _offset_contribution_output_stage_kernel(), _activation_func(), _convert_to_signed_asymm(), _convert_from_signed_asymm(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), - _tmp_b(), _mm_result_s32(), _signed_a(), _signed_output(), _original_b(nullptr), _a_offset(0), _b_offset(0), _run_vector_matrix_multiplication(false), _assembly_path(false), - _fused_assembly_path(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _fuse_output_stage(false), _run_activation(false), _flip_signedness(false) +NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr memory_manager, IWeightsManager *weights_manager) + : _memory_group(memory_manager), _weights_manager(weights_manager), _asm_glue(memory_manager, weights_manager), _mm_kernel(), _mtx_a_reshape_kernel(), _mtx_b_reshape_kernel(), + _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(), _offset_contribution_kernel(), _offset_contribution_output_stage_kernel(), _activation_func(), _convert_to_signed_asymm(), + _convert_from_signed_asymm(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _mm_result_s32(), _signed_a(), _signed_output(), _original_b(nullptr), _a_offset(0), _b_offset(0), + _run_vector_matrix_multiplication(false), _assembly_path(false), _fused_assembly_path(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _fuse_output_stage(false), + _run_activation(false), _flip_signedness(false) { } @@ -60,10 +58,6 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, const ITensor *matrix_b = b; GEMMInfo info = gemm_info; - // Clear state - _mtx_a_reshape_kernel = nullptr; - _mtx_b_reshape_kernel = nullptr; - // Set internal variables _a_offset = a->info()->quantization_info().uniform().offset; _b_offset = b->info()->quantization_info().uniform().offset; @@ -158,18 +152,10 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, } // Configure interleave kernel - { - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(a_to_use, &_tmp_a); - _mtx_a_reshape_kernel = std::move(k); - } + _mtx_a_reshape_kernel.configure(a_to_use, &_tmp_a); // Configure transpose kernel - { - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(b, &_tmp_b); - _mtx_b_reshape_kernel = std::move(k); - } + _mtx_b_reshape_kernel.configure(b, &_tmp_b); } if(!_fused_assembly_path) @@ -209,9 +195,7 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, // Configure matrix multiply kernel if(!_assembly_path) { - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(matrix_a, matrix_b, &_mm_result_s32); - _mm_kernel = std::move(k); + _mm_kernel.configure(matrix_a, matrix_b, &_mm_result_s32); } _offset_contribution_output_stage_kernel.configure(&_mm_result_s32, @@ -231,21 +215,19 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, // Configure matrix multiply kernel if(!_assembly_path) { - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(matrix_a, matrix_b, output); - _mm_kernel = std::move(k); + _mm_kernel.configure(matrix_a, matrix_b, output); } // Configure offset contribution kernel _offset_contribution_kernel.configure(output, _a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, a_to_use->info()->dimension(0), _a_offset, _b_offset); } - } - // Configure activation - const ActivationLayerInfo &activation = gemm_info.activation_info(); - _run_activation = activation.enabled() && (!_assembly_path || (_assembly_path && !NEGEMMAssemblyDispatch::is_activation_supported(activation))); - if(_run_activation) - { - _activation_func.configure(output, nullptr, activation); + // Configure activation + const ActivationLayerInfo &activation = gemm_info.activation_info(); + _run_activation = activation.enabled() && (!_assembly_path || (_assembly_path && !NEGEMMAssemblyDispatch::is_activation_supported(activation))); + if(_run_activation) + { + _activation_func.configure(output, nullptr, activation); + } } // Allocate tensors @@ -495,16 +477,6 @@ void NEGEMMLowpMatrixMultiplyCore::run() NEScheduler::get().schedule(&_convert_to_signed_asymm, Window::DimY); } - // Reshape inputs - if(_mtx_a_reshape_kernel) - { - NEScheduler::get().schedule(_mtx_a_reshape_kernel.get(), Window::DimY); - } - if(_mtx_b_reshape_kernel && !_reshape_b_only_on_first_run) - { - NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY); - } - // Run GEMM if(_asm_glue.is_configured()) { @@ -512,7 +484,18 @@ void NEGEMMLowpMatrixMultiplyCore::run() } else { - NEScheduler::get().schedule(_mm_kernel.get(), Window::DimY); + if(!_run_vector_matrix_multiplication) + { + // Run interleave kernel + NEScheduler::get().schedule(&_mtx_a_reshape_kernel, Window::DimY); + + if(!_reshape_b_only_on_first_run) + { + // Run transpose kernel + NEScheduler::get().schedule(&_mtx_b_reshape_kernel, Window::DimY); + } + } + NEScheduler::get().schedule(&_mm_kernel, Window::DimY); } if(!_fused_assembly_path) @@ -547,8 +530,8 @@ void NEGEMMLowpMatrixMultiplyCore::run() NEScheduler::get().schedule(&_convert_from_signed_asymm, Window::DimY); } - // Run fused activation - if(_run_activation) + // Run fused activation unless already run in the fused assembly + if(_run_activation && !_fused_assembly_path) { _activation_func.run(); } @@ -558,23 +541,36 @@ void NEGEMMLowpMatrixMultiplyCore::prepare() { if(!_is_prepared) { + const bool original_b_managed_by_weights_manager = _weights_manager && _weights_manager->are_weights_managed(_original_b); // Run assembly reshape - if(_asm_glue.is_configured() && _reshape_b_only_on_first_run) + if(_asm_glue.is_configured()) { - ARM_COMPUTE_ERROR_ON(!_original_b->is_used()); + if(!original_b_managed_by_weights_manager) + { + ARM_COMPUTE_ERROR_ON(!_original_b->is_used()); + } _asm_glue.prepare(); - _original_b->mark_as_unused(); + if(!original_b_managed_by_weights_manager) + { + _original_b->mark_as_unused(); + } } // Run non-assembly reshape - else if(_mtx_b_reshape_kernel && _reshape_b_only_on_first_run) + else if(_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication && !_asm_glue.is_configured()) { - ARM_COMPUTE_ERROR_ON(!_original_b->is_used()); + if(!original_b_managed_by_weights_manager) + { + ARM_COMPUTE_ERROR_ON(!_original_b->is_used()); + } // Run reshape kernel and mark original weights tensor as unused _tmp_b.allocator()->allocate(); - NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY); - _original_b->mark_as_unused(); + NEScheduler::get().schedule(&_mtx_b_reshape_kernel, Window::DimY); + if(!original_b_managed_by_weights_manager) + { + _original_b->mark_as_unused(); + } } // Run matrix B reduction kernel only if _a_offset is not equal to 0 diff --git a/tests/datasets/FullyConnectedLayerDataset.h b/tests/datasets/FullyConnectedLayerDataset.h index 085e9c76b6..06f74ec588 100644 --- a/tests/datasets/FullyConnectedLayerDataset.h +++ b/tests/datasets/FullyConnectedLayerDataset.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -114,6 +114,17 @@ private: std::vector _dst_shapes{}; }; +class FullyConnectedLayerWithActivationDataset final : public FullyConnectedLayerDataset +{ +public: + FullyConnectedLayerWithActivationDataset() + { + // Conv -> FC + add_config(TensorShape(8U, 1U, 1U), TensorShape(8U, 16U), TensorShape(16U), TensorShape(16U)); + // FC -> FC + add_config(TensorShape(1U), TensorShape(1U, 10U), TensorShape(10U), TensorShape(10U)); + } +}; class TinyFullyConnectedLayerDataset final : public FullyConnectedLayerDataset { public: diff --git a/tests/validation/NEON/FullyConnectedLayer.cpp b/tests/validation/NEON/FullyConnectedLayer.cpp index cd2986a1e4..523b3c62f1 100644 --- a/tests/validation/NEON/FullyConnectedLayer.cpp +++ b/tests/validation/NEON/FullyConnectedLayer.cpp @@ -71,8 +71,24 @@ const auto QuantizationData = framework::dataset::make("QuantizationInfo", QuantizationInfo(1.f / 256.f, 10), QuantizationInfo(1.1f, 10), }); +const auto EmptyActivationFunctionDataset = framework::dataset::make("ActivationInfo", +{ + ActivationLayerInfo(), +}); +const auto ActivationFunctionsDataset = framework::dataset::make("ActivationInfo", +{ + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 0.5f), + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 0.75f, 0.25f), + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH), +}); -const auto ActivationFunctionsDataset = framework::dataset::make("ActivationInfo", ActivationLayerInfo()); +const auto ActivationFunctionsQuantizedDataset = framework::dataset::make("ActivationInfo", +{ + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 0.5f), + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 0.75f, 0.25f), +}); } // namespace TEST_SUITE(NEON) @@ -134,7 +150,16 @@ TEST_SUITE(FP16) FIXTURE_DATA_TEST_CASE(RunSmall, NEFullyConnectedLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(), FullyConnectedParameters), framework::dataset::make("DataType", DataType::F16)), - ActivationFunctionsDataset)) + EmptyActivationFunctionDataset)) +{ + // Validate output + validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16, abs_tolerance_f16); +} +FIXTURE_DATA_TEST_CASE(RunWithActivation, NEFullyConnectedLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine( + combine(datasets::FullyConnectedLayerWithActivationDataset(), + FullyConnectedParameters), + framework::dataset::make("DataType", DataType::F16)), + ActivationFunctionsDataset)) { // Validate output validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16, abs_tolerance_f16); @@ -142,7 +167,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEFullyConnectedLayerFixture, framework:: FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeFullyConnectedLayerDataset(), FullyConnectedParameters), framework::dataset::make("DataType", DataType::F16)), - ActivationFunctionsDataset)) + EmptyActivationFunctionDataset)) { // Validate output validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16, abs_tolerance_f16); @@ -153,14 +178,23 @@ TEST_SUITE_END() TEST_SUITE(FP32) FIXTURE_DATA_TEST_CASE(RunSmall, NEFullyConnectedLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(), FullyConnectedParameters), framework::dataset::make("DataType", DataType::F32)), - ActivationFunctionsDataset)) + EmptyActivationFunctionDataset)) +{ + // Validate output + validate(Accessor(_target), _reference, rel_tolerance_f32, 0, abs_tolerance_f32); +} +FIXTURE_DATA_TEST_CASE(RunWithActivation, NEFullyConnectedLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine( + combine(datasets::FullyConnectedLayerWithActivationDataset(), + FullyConnectedParameters), + framework::dataset::make("DataType", DataType::F32)), + ActivationFunctionsDataset)) { // Validate output validate(Accessor(_target), _reference, rel_tolerance_f32, 0, abs_tolerance_f32); } FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeFullyConnectedLayerDataset(), FullyConnectedParameters), framework::dataset::make("DataType", DataType::F32)), - ActivationFunctionsDataset)) + EmptyActivationFunctionDataset)) { // Validate output validate(Accessor(_target), _reference, rel_tolerance_f32, 0, abs_tolerance_f32); @@ -178,17 +212,29 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEFullyConnectedLayerQuantizedFixture, FullyConnectedParameters), framework::dataset::make("DataType", DataType::QASYMM8)), QuantizationData), - ActivationFunctionsDataset)) + EmptyActivationFunctionDataset)) { // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); } + +FIXTURE_DATA_TEST_CASE(RunWithActivation, NEFullyConnectedLayerQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine( + combine(datasets::FullyConnectedLayerWithActivationDataset(), + FullyConnectedParameters), + framework::dataset::make("DataType", DataType::QASYMM8)), + QuantizationData), + ActivationFunctionsQuantizedDataset)) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_qasymm8); +} + FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerQuantizedFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine( combine(datasets::LargeFullyConnectedLayerDataset(), FullyConnectedParameters), framework::dataset::make("DataType", DataType::QASYMM8)), QuantizationData), - ActivationFunctionsDataset)) + EmptyActivationFunctionDataset)) { // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); @@ -200,7 +246,18 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEFullyConnectedLayerQuantizedFixture, FullyConnectedParameters), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), QuantizationData), - ActivationFunctionsDataset)) + EmptyActivationFunctionDataset)) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_qasymm8_signed); +} + +FIXTURE_DATA_TEST_CASE(RunWithActivation, NEFullyConnectedLayerQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine( + combine(datasets::FullyConnectedLayerWithActivationDataset(), + FullyConnectedParameters), + framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), + QuantizationData), + ActivationFunctionsQuantizedDataset)) { // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8_signed); -- cgit v1.2.1