From 13b623e575ed2f1096c70560a2db4a9e03cf22f9 Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Wed, 27 Jul 2022 17:53:21 +0000 Subject: [ONCPUML-968] Fixed format kernel support in additional APIs Implements required plumbing in order to be able to ask and execute fixed format kernels from NEFullyConnected, NEGEMM and NEGEMMConv2d. These APIs are used to accelerate oneDNN primitives (inner product, matrix multiplication and indirect GEMM respectively) and without changes it would not be possible to call fixed format kernels from those oneDNN primitives. Change-Id: I27534f0491ce28d0ccb98c19f318bd33dcdf2ff5 Signed-off-by: Milos Puzovic Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7999 Reviewed-by: Gian Marco Iodice Reviewed-by: Pablo Marquez Tello Reviewed-by: SiCong Li Reviewed-by: Gunes Bayir Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- arm_compute/core/Types.h | 75 ++++++++++++++-------- arm_compute/runtime/FunctionDescriptors.h | 8 ++- .../runtime/NEON/functions/NEFullyConnectedLayer.h | 42 ++++++++---- arm_compute/runtime/NEON/functions/NEGEMM.h | 11 +++- src/cpu/operators/CpuFullyConnected.cpp | 43 ++++++++----- src/cpu/operators/CpuFullyConnected.h | 55 ++++++++++------ src/cpu/operators/CpuGemmDirectConv2d.cpp | 23 +++++-- .../operators/internal/CpuGemmAssemblyDispatch.cpp | 55 ++++++++++------ .../NEON/functions/NEFullyConnectedLayer.cpp | 13 +++- src/runtime/NEON/functions/NEGEMM.cpp | 9 ++- 10 files changed, 226 insertions(+), 108 deletions(-) diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index c87c97cb06..66e1c8ab1f 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -774,10 +774,10 @@ public: private: std::pair _stride; - unsigned int _pad_left; - unsigned int _pad_top; - unsigned int _pad_right; - unsigned int _pad_bottom; + unsigned int _pad_left; + unsigned int _pad_top; + unsigned int _pad_right; + unsigned int _pad_bottom; DimensionRoundingType _round_type; }; @@ -919,14 +919,14 @@ public: } private: - std::vector _min_sizes; - std::vector _variances; - float _offset; - bool _flip; - bool _clip; - std::vector _max_sizes; - std::vector _aspect_ratios; - Coordinates2D _img_size; + std::vector _min_sizes; + std::vector _variances; + float _offset; + bool _flip; + bool _clip; + std::vector _max_sizes; + std::vector _aspect_ratios; + Coordinates2D _img_size; std::array _steps; }; @@ -1171,15 +1171,15 @@ public: } private: - unsigned int _max_detections; - unsigned int _max_classes_per_detection; - float _nms_score_threshold; - float _iou_threshold; - unsigned int _num_classes; + unsigned int _max_detections; + unsigned int _max_classes_per_detection; + float _nms_score_threshold; + float _iou_threshold; + unsigned int _num_classes; std::array _scales_values; - bool _use_regular_nms; - unsigned int _detection_per_class; - bool _dequantize_scores; + bool _use_regular_nms; + unsigned int _detection_per_class; + bool _dequantize_scores; }; /** Pooling Layer Information struct*/ @@ -1612,13 +1612,13 @@ public: } private: - float _img_width; - float _img_height; - float _scale; - bool _apply_scale; - bool _correct_transform_coords; + float _img_width; + float _img_height; + float _scale; + bool _apply_scale; + bool _correct_transform_coords; std::array _weights; - float _bbox_xform_clip; + float _bbox_xform_clip; }; /** Activation Layer Information class */ @@ -2053,6 +2053,11 @@ public: { return _weight_format; } + void set_weight_format(arm_compute::WeightFormat weight_format) + { + _weight_format = weight_format; + } + unsigned int kernel_width() const { return _kernel_width; @@ -2495,11 +2500,29 @@ public: return _fixed_format; } + /** Set fixed-format flag + * + * @param[in] fixed_format sets whether or not to use fixed-format kernels + */ + void set_fixed_format(bool fixed_format) + { + _fixed_format = fixed_format; + } + arm_compute::WeightFormat weight_format() const { return _weight_format; } + /** Set weight format to be used + * + * @param[in] weight_format arm_compute::WeightFormat enumeration + */ + void set_weight_format(arm_compute::WeightFormat weight_format) + { + _weight_format = weight_format; + } + private: bool _is_a_reshaped; bool _is_b_reshaped; diff --git a/arm_compute/runtime/FunctionDescriptors.h b/arm_compute/runtime/FunctionDescriptors.h index face8a6fb4..af79820bc3 100644 --- a/arm_compute/runtime/FunctionDescriptors.h +++ b/arm_compute/runtime/FunctionDescriptors.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -62,8 +62,9 @@ struct Conv2dInfo const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups, - const experimental::PostOpList &post_ops = experimental::PostOpList {}) - : conv_info(conv_info), dilation(dilation), act_info(act_info), enable_fast_math(enable_fast_math), num_groups(num_groups), post_ops(post_ops) + const experimental::PostOpList &post_ops = experimental::PostOpList {}, + const WeightsInfo &weights_info = WeightsInfo()) + : conv_info(conv_info), dilation(dilation), act_info(act_info), enable_fast_math(enable_fast_math), num_groups(num_groups), post_ops(post_ops), weights_info(weights_info) { } @@ -73,6 +74,7 @@ struct Conv2dInfo bool enable_fast_math{ false }; unsigned int num_groups{ 1 }; experimental::PostOpList post_ops{}; + WeightsInfo weights_info{}; }; /** Descriptor used by the 3d Convolution function */ diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h index aa96716d38..2b4f848b22 100644 --- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h +++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -112,20 +112,21 @@ public: * |QASYMM8 |QASYMM8 |S32 |QASYMM8 | * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED | * - * @param[in] input Source tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. - * @param[in] weights Weights tensor. The weights must be 2 dimensional. - * If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions. - * If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension. - * Data type supported: Same as @p input. - * @param[in] biases Bias tensor. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED. - * @param[out] output Destination tensor. Its shape should be equal to the output of a matrix multiplication between: - * - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer - * - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer. - * Data type supported: Same as @p input. - * @param[in] fc_info (Optional) Fully connected layer additional info + * @param[in] input Source tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. + * @param[in] weights Weights tensor. The weights must be 2 dimensional. + * If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions. + * If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension. + * Data type supported: Same as @p input. + * @param[in] biases Bias tensor. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED. + * @param[out] output Destination tensor. Its shape should be equal to the output of a matrix multiplication between: + * - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer + * - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer. + * Data type supported: Same as @p input. + * @param[in] fc_info (Optional) Fully connected layer additional info + * @param[in] weights_info (Optional) Stores neccessary compute information when weights are already reshaped */ void configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, - FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo()); + FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo()); /** Static function to check if given info will lead to a valid configuration of @ref NEFullyConnectedLayer * * Similar to @ref NEFullyConnectedLayer @@ -135,6 +136,21 @@ public: static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo()); + /** Static function that queries whether fixed-format kernel exists for a given problem description + * + * @param[out] expected_weight_format Format in which weights should be for found fixed format kernel + * @param[in] input Source tensor + * @param[in] weights Weights tensor. + * @param[in] biases Bias tensor. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED. + * @param[in] output Destination tensor + * @param[in] fc_info Fully connected layer additional info + * @param[in] weights_info Describes weights shape + * + * @return a status + */ + static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *input, const ITensorInfo *weights, + const ITensorInfo *biases, const ITensorInfo *output, const FullyConnectedLayerInfo &fc_info, const WeightsInfo &weights_info); + //Inherited methods override void run() override; void prepare() override; diff --git a/arm_compute/runtime/NEON/functions/NEGEMM.h b/arm_compute/runtime/NEON/functions/NEGEMM.h index ce68a61923..7ce2521148 100644 --- a/arm_compute/runtime/NEON/functions/NEGEMM.h +++ b/arm_compute/runtime/NEON/functions/NEGEMM.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -84,6 +84,15 @@ public: */ static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo()); + /** Static function that queries whether there exists fixed-format kernel and if it exists it will return in the first argument in what format + * weights are expected to be reshaped as defined by WeightFormat class. Apart from the first argument the rest of the arguments are the same + * as in @ref NEGEMM::validate() except that all arguments are required. + * + * @return a status + */ + static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, + float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo()); + // Inherited methods overridden: void run() override; void prepare() override; diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp index 6d77c614f7..3172644488 100644 --- a/src/cpu/operators/CpuFullyConnected.cpp +++ b/src/cpu/operators/CpuFullyConnected.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -53,7 +53,7 @@ std::pair get_quantized_asymmetric_output_min_max(const { PixelValue type_min{}; PixelValue type_max{}; - std::tie(type_min, type_max) = get_min_max(data_type); + std::tie(type_min, type_max) = get_min_max(data_type); const UniformQuantizationInfo q_unif = q_info.uniform(); if(act_info.enabled()) @@ -162,8 +162,9 @@ CpuFullyConnected::CpuFullyConnected() _is_fc_after_conv(false), _is_quantized_asymmetric(false), _is_prepared(false), - _enable_fast_math(false) - + _enable_fast_math(false), + _fixed_format(false), + _weight_format(arm_compute::WeightFormat::UNSPECIFIED) { } @@ -199,6 +200,8 @@ void CpuFullyConnected::configure_mm(const ITensorInfo *src, const ITensorInfo * GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */); gemm_info.set_activation_info(act); gemm_info.set_fast_math(_enable_fast_math); + gemm_info.set_fixed_format(_fixed_format); + gemm_info.set_weight_format(_weight_format); _mm_gemm = std::make_unique(); _mm_gemm->configure(src, weights, biases, dst, 1.f, 1.0f, gemm_info); } @@ -229,7 +232,7 @@ void CpuFullyConnected::configure_fc_fc(const ITensorInfo *src, const ITensorInf } void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, - FullyConnectedLayerInfo fc_info) + FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info) { // Perform validate step ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst); @@ -248,6 +251,8 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei _is_prepared = false; _trans_weights_idx = AuxTensorIdx::Count; _enable_fast_math = fc_info.enable_fast_math; + _fixed_format = weights_info.weight_format() != WeightFormat::UNSPECIFIED; + _weight_format = weights_info.weight_format(); // With the Fully Connected layer we can have 4 different cases: // 1) Convolution layer -> Fully Connected layer without batches @@ -261,9 +266,7 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei const bool is_batched_fc_layer = dst->dimension(1) > 1; if(is_batched_fc_layer) { - _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, - src->tensor_shape().cend(), - dst->tensor_shape().cbegin() + 1)); + _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, src->tensor_shape().cend(), dst->tensor_shape().cbegin() + 1)); } else { @@ -323,12 +326,10 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei { // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch // Do not release them if biases are dynamic and data type is quantized, since the weights tensor will be used for biases offset calculation - _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric - && biases && !(biases->are_values_constant())) ? - MemoryLifetime::Persistent : - MemoryLifetime::Prepare, + _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric && biases + && !(biases->are_values_constant())) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare, _reshaped_weights.total_size()); - _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size()); + _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size()); } else { @@ -338,6 +339,18 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei _aux_mem[FlattenedSrc] = MemoryInfo(offset_int_vec(FlattenedSrc), MemoryLifetime::Temporary, _flattened_src.total_size()); } +Status CpuFullyConnected::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, + const ITensorInfo *biases, const ITensorInfo *dst, FullyConnectedLayerInfo fc_info, WeightsInfo weights_info) +{ + GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */); + gemm_info.set_activation_info(fc_info.activation_info); + gemm_info.set_fast_math(fc_info.enable_fast_math); + gemm_info.set_fixed_format(weights_info.weight_format() != WeightFormat::UNSPECIFIED); + gemm_info.set_weight_format(weights_info.weight_format()); + + return CpuGemm::has_opt_impl(expected_weight_format, src, weights, biases, dst, gemm_info); +} + Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, FullyConnectedLayerInfo fc_info) { @@ -384,9 +397,7 @@ Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *we if(is_batched_fc_layer) { - is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, - src->tensor_shape().cend(), - dst->tensor_shape().cbegin() + 1)); + is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, src->tensor_shape().cend(), dst->tensor_shape().cbegin() + 1)); } else { diff --git a/src/cpu/operators/CpuFullyConnected.h b/src/cpu/operators/CpuFullyConnected.h index 44fa21f9f8..36511e9d32 100644 --- a/src/cpu/operators/CpuFullyConnected.h +++ b/src/cpu/operators/CpuFullyConnected.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -72,20 +72,21 @@ public: * |QASYMM8 |QASYMM8 |S32 |QASYMM8 | * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED | * - * @param[in] src Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. - * @param[in] weights Weights tensor info. The weights must be 2 dimensional. - * If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions. - * If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension. - * Data type supported: Same as @p src. - * @param[in] biases Bias tensor info. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED. - * @param[out] dst Destination tensor info. Its shape should be equal to the output of a matrix multiplication between: - * - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer - * - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer. - * Data type supported: Same as @p src. - * @param[in] fc_info (Optional) Fully connected layer additional info + * @param[in] src Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. + * @param[in] weights Weights tensor info. The weights must be 2 dimensional. + * If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions. + * If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension. + * Data type supported: Same as @p src. + * @param[in] biases Bias tensor info. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED. + * @param[out] dst Destination tensor info. Its shape should be equal to the output of a matrix multiplication between: + * - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer + * - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer. + * Data type supported: Same as @p src. + * @param[in] fc_info (Optional) Fully connected layer additional info + * @param[in] weights_info (Optional) Stores neccessary compute information when weights are already reshaped */ void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, - FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo()); + FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo()); /** Static function to check if given info will lead to a valid configuration of @ref CpuFullyConnected * * Similar to @ref CpuFullyConnected @@ -95,9 +96,19 @@ public: static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo()); + /** Static function that queries whether there exists fixed-format kernel and if it exists it will return in the first argument in what format + * weights are expected to be reshaped as defined by WeightFormat class. Apart from the first argument the rest of the arguments are the same + * as in @ref CpuFullyConnectedLayer::validate() except that all arguments are required. + * + * @return a status + */ + static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, + const ITensorInfo *biases, const ITensorInfo *dst, + FullyConnectedLayerInfo fc_info, WeightsInfo weights_info); + //Inherited methods override - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; experimental::MemoryRequirements workspace() const override; private: @@ -136,12 +147,14 @@ private: experimental::MemoryRequirements _aux_mem; - bool _needs_weights_conversion; - bool _needs_weights_reshape; - bool _is_fc_after_conv; - bool _is_quantized_asymmetric; - bool _is_prepared; - bool _enable_fast_math; + bool _needs_weights_conversion; + bool _needs_weights_reshape; + bool _is_fc_after_conv; + bool _is_quantized_asymmetric; + bool _is_prepared; + bool _enable_fast_math; + bool _fixed_format; + arm_compute::WeightFormat _weight_format; }; } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/operators/CpuGemmDirectConv2d.cpp b/src/cpu/operators/CpuGemmDirectConv2d.cpp index fd1a042bb4..ee47a17d64 100644 --- a/src/cpu/operators/CpuGemmDirectConv2d.cpp +++ b/src/cpu/operators/CpuGemmDirectConv2d.cpp @@ -57,11 +57,11 @@ GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU }; - PixelValue type_min{}; - PixelValue type_max{}; + PixelValue type_min{}; + PixelValue type_max{}; std::tie(type_min, type_max) = get_min_max(data_type); - int32_t min_activation = type_min.get(); - int32_t max_activation = type_max.get(); + int32_t min_activation = type_min.get(); + int32_t max_activation = type_max.get(); if(supported_acts.count(act.activation()) != 0) { std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act, data_type, uoqinfo); @@ -88,6 +88,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect asm_info.padding_value = 0.f; asm_info.negated_offsets = false; asm_info.fast_mode = info.enable_fast_math; + asm_info.fixed_format = info.weights_info.weight_format() != WeightFormat::UNSPECIFIED; + asm_info.weight_format = info.weights_info.weight_format(); return asm_info; } } // namespace @@ -146,7 +148,9 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w } else { - _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size()); + // We must permute weights if they are WeightFormat::UNSPECIFIED + if(info.weights_info.weight_format() == WeightFormat::UNSPECIFIED) + _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size()); } } Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info) @@ -203,6 +207,13 @@ void CpuGemmDirectConv2d::prepare(ITensorPack &tensors) { if(!_is_prepared) { + // If we are using fixed-format kernel the weights are already reshaped + if(_gemm_asm_func && _gemm_asm_func->isVarWeightsKernel()) + { + _gemm_asm_func->prepare(tensors); + _is_prepared = true; + return; + } const ITensor *weights = tensors.get_const_tensor(ACL_SRC_1); ITensor *weights_aux = utils::cast::polymorphic_cast(tensors.get_tensor(offset_int_vec(PermutedWeights))); ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux); @@ -224,4 +235,4 @@ experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const return _aux_mem; } } // namespace cpu -} // namespace arm_compute \ No newline at end of file +} // namespace arm_compute diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index 45b3232423..df02d649f8 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -156,8 +156,8 @@ public: const std::vector &multipliers); // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; bool is_configured() const override; experimental::MemoryRequirements workspace() const override; bool isVarWeightsKernel() const override @@ -210,12 +210,12 @@ private: /** Indirect buffer */ std::unique_ptr _indirect_arg{}; std::unique_ptr _indirect_buf{}; - std::vector _indirect_pad{}; - arm_gemm::ConvolutionParameters _cp{}; - experimental::MemoryRequirements _aux_mem{ Count }; - bool _B_pretranspose_required{ false }; - bool _is_b_constant{ true }; - bool _is_c_constant{ true }; + std::vector _indirect_pad{}; + arm_gemm::ConvolutionParameters _cp{}; + experimental::MemoryRequirements _aux_mem{ Count }; + bool _B_pretranspose_required{ false }; + bool _is_b_constant{ true }; + bool _is_c_constant{ true }; }; template @@ -493,6 +493,7 @@ void Fallback::run(ITensorPack &tensors) if(!_gemm_kernel_asm->B_is_pretransposed()) { ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput); + multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput); const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format); if(is_fixed_format(wf)) { @@ -501,17 +502,35 @@ void Fallback::run(ITensorPack &tensors) // as a 2D tensor at arm_gemm level, where the rows are // O'/ and the columns are * // H * W * I'. - ITensorInfo *tensor_info = b->info(); - const DataLayout data_layout = tensor_info->data_layout(); - const TensorShape tensor_shape = tensor_info->tensor_shape(); - const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)]; - const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; - const int Ip = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; - const int interleave_by = arm_compute::interleave_by(wf); - ldb = (interleave_by * H * W * Ip); + ITensorInfo *tensor_info = b->info(); + const DataLayout data_layout = tensor_info->data_layout(); + const TensorShape tensor_shape = tensor_info->tensor_shape(); + const int tensor_height = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)]; + const int tensor_width = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; + const int tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; + const int interleave_by = arm_compute::interleave_by(wf); + // We need to find a new stride that is distance from the data for one + // set of output channels to the next + if(ldb == tensor_channels && multi_stride_b == tensor_channels * tensor_width) + { + // In this case dimensions that are packed are height, width and channel + // so we need to stride it by interleave_by + ldb = interleave_by * tensor_height * tensor_width * tensor_channels; + } + else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width)) + { + // In this case dimension that is packed is only height + // so we need to stride only height by interleave_by + ldb = interleave_by * tensor_height; + } + else + { + // If dimensions are not packed as above error is thrown + // as at the moment other forms of packing are not supported + ARM_COMPUTE_ERROR("Unsupported packing for fixed format kernel"); + } } - multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput); - in1_ptr = reinterpret_cast(b->buffer() + b->info()->offset_first_element_in_bytes()); + in1_ptr = reinterpret_cast(b->buffer() + b->info()->offset_first_element_in_bytes()); } // If necessary, run pretranspose every time if either weights or biases are non-constant diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp index 77028d96a2..4f858fb54b 100644 --- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp +++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -61,7 +61,7 @@ NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr mem } void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, - FullyConnectedLayerInfo fc_info) + FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info) { // Perform validate step ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); @@ -76,7 +76,7 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh _impl->original_weights = weights; _impl->is_prepared = false; - _impl->op->configure(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), fc_info); + _impl->op->configure(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), fc_info, weights_info); if(_impl->weights_manager != nullptr) { @@ -88,6 +88,13 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh _impl->workspace = manage_workspace(_impl->aux_mem_req, _impl->memory_group, _impl->run_pack, _impl->run_pack); } +Status NEFullyConnectedLayer::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *input, const ITensorInfo *weights, + const ITensorInfo *biases, const ITensorInfo *output, const FullyConnectedLayerInfo &fc_info, + const WeightsInfo &weights_info) +{ + return cpu::CpuFullyConnected::has_opt_impl(expected_weight_format, input, weights, biases, output, fc_info, weights_info); +} + Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, FullyConnectedLayerInfo fc_info) { diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp index 58ade9fb3a..0266c48f86 100644 --- a/src/runtime/NEON/functions/NEGEMM.cpp +++ b/src/runtime/NEON/functions/NEGEMM.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -84,6 +84,13 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso return cpu::CpuGemm::validate(a, b, c, output, alpha, beta, gemm_info); } +Status NEGEMM::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, + float alpha, float beta, const GEMMInfo &gemm_info) +{ + ARM_COMPUTE_UNUSED(alpha, beta); + return cpu::CpuGemm::has_opt_impl(expected_weight_format, a, b, c, output, gemm_info); +} + void NEGEMM::run() { prepare(); -- cgit v1.2.1