From aed63ee175e0d64c934389e9d1b2edd0cb1a5cdd Mon Sep 17 00:00:00 2001
From: Michele Di Giorgio
Date: Mon, 26 Jul 2021 13:18:50 +0100
Subject: Add support for non-constant weights and biases in CpuFullyConnected

Change the approach for specifying that weights and biases tensors are
non-constant: the flag is now a member of TensorInfo rather than an
option passed to the functions.

Resolves: COMPMID-4222
Change-Id: I96e6f3868f51785c9700a3ef6a1fe7b05747862c
Signed-off-by: Michele Di Giorgio
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6162
Tested-by: Arm Jenkins
Comments-Addressed: Arm Jenkins
Reviewed-by: Georgios Pinitas
---
 arm_compute/core/ITensorInfo.h                     |  12 ++
 arm_compute/core/SubTensorInfo.h                   |  11 ++
 arm_compute/core/TensorInfo.h                      |  10 ++
 arm_compute/core/Types.h                           |  33 ++----
 .../NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp |   6 +-
 .../kernels/arm_gemm/gemm_hybrid_quantized.hpp     |   6 +-
 .../arm_gemm/gemm_hybrid_quantized_inline.hpp      |   6 +-
 .../NEON/kernels/arm_gemm/gemm_interleaved.hpp     |   6 +-
 .../NEON/kernels/arm_gemm/gemv_pretransposed.hpp   |   8 +-
 .../NEON/kernels/arm_gemm/quantize_wrapper.hpp     |   9 +-
 src/core/TensorInfo.cpp                            |   8 +-
 src/cpu/kernels/assembly/gemm_common.hpp           |   3 +
 src/cpu/operators/CpuFullyConnected.cpp            |  27 ++++-
 .../operators/internal/CpuGemmAssemblyDispatch.cpp |  31 +++++
 src/gpu/cl/operators/ClFullyConnected.cpp          |  18 ++-
 src/gpu/cl/operators/ClGemm.cpp                    |   4 +-
 tests/validation/NEON/FullyConnectedLayer.cpp      |  15 +++
 .../fixtures/FullyConnectedLayerFixture.h          | 126 +++++++++++++++++----
 utils/TypePrinter.h                                |   2 +-
 19 files changed, 275 insertions(+), 66 deletions(-)

diff --git a/arm_compute/core/ITensorInfo.h b/arm_compute/core/ITensorInfo.h index 0171e31086..bc3a6bed8c 100644 --- a/arm_compute/core/ITensorInfo.h +++ b/arm_compute/core/ITensorInfo.h @@ -240,6 +240,11 @@ public: * @return True if its dynamic else false */ virtual bool is_dynamic() const = 0; + /** Flag indicating whether the values of the tensor are constant, meaning that they cannot change on kernel/function execution. + * + * @return True if values are constant else false + */ + virtual bool are_values_constant() const = 0; /** Set the flag whether the tensor size can be changed. * * @param[in] is_resizable Flag that marks the tensor if it can be changed or not. * * @return Reference to this ITensorInfo object */ virtual ITensorInfo &set_is_resizable(bool is_resizable) = 0; + /** Set the flag specifying whether the tensor values are constant or may change during kernel/function execution. + * + * @param[in] are_values_constant True if the tensor values are constant, false if they may change between executions + * + * @return Reference to this ITensorInfo object + */ + virtual ITensorInfo &set_are_values_constant(bool are_values_constant) = 0; /** Valid region of the tensor. All elements in the valid region have defined values, i.e. are not undefined. * * @return The valid region.
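As a usage sketch for the new are_values_constant flag introduced above (illustrative only, not part of the patch; the shapes, data type and function name are made up for the example): a tensor whose contents will be rewritten between executions is tagged through TensorInfo::set_are_values_constant(false), while a default-constructed info keeps reporting its values as constant.

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"

using namespace arm_compute;

void tag_mutable_weights()
{
    // Weights that will be re-written between runs: mark their values as non-constant
    TensorInfo weights_info(TensorShape(64U, 16U), 1, DataType::F32);
    weights_info.set_are_values_constant(false);

    // Default behaviour is unchanged: a plain TensorInfo still reports constant values
    TensorInfo bias_info(TensorShape(16U), 1, DataType::F32);
    // bias_info.are_values_constant() == true
}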
diff --git a/arm_compute/core/SubTensorInfo.h b/arm_compute/core/SubTensorInfo.h index 1b2278d99b..54836d0528 100644 --- a/arm_compute/core/SubTensorInfo.h +++ b/arm_compute/core/SubTensorInfo.h @@ -196,12 +196,23 @@ public: ARM_COMPUTE_ERROR_ON(_parent == nullptr); return _parent->is_dynamic(); } + bool are_values_constant() const override + { + ARM_COMPUTE_ERROR_ON(_parent == nullptr); + return _parent->are_values_constant(); + } ITensorInfo &set_is_resizable(bool is_resizable) override { ARM_COMPUTE_ERROR_ON(_parent == nullptr); _parent->set_is_resizable(is_resizable); return *this; } + ITensorInfo &set_are_values_constant(bool are_values_constant) override + { + ARM_COMPUTE_ERROR_ON(_parent == nullptr); + _parent->set_are_values_constant(are_values_constant); + return *this; + } ValidRegion valid_region() const override { return _valid_region; diff --git a/arm_compute/core/TensorInfo.h b/arm_compute/core/TensorInfo.h index a4330849bf..9bc86806fb 100644 --- a/arm_compute/core/TensorInfo.h +++ b/arm_compute/core/TensorInfo.h @@ -267,6 +267,10 @@ public: { return std::find(std::cbegin(_dims_state), std::cend(_dims_state), get_dynamic_state_value()) != std::cend(_dims_state); } + bool are_values_constant() const override + { + return _are_values_constant; + } ITensorInfo &set_is_resizable(bool is_resizable) override { _is_resizable = is_resizable; @@ -288,6 +292,11 @@ public: { return _data_layout; } + ITensorInfo &set_are_values_constant(bool are_values_constant) override + { + _are_values_constant = are_values_constant; + return *this; + } private: /** Calculates strides, offset and total size resulting from the specified padding around the XY plane. @@ -309,6 +318,7 @@ private: PaddingSize _padding; QuantizationInfo _quantization_info; DataLayout _data_layout; + bool _are_values_constant; }; } // namespace arm_compute #endif /*ARM_COMPUTE_TENSORINFO_H */ diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 9c00cbc88c..36b77b8224 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -1544,7 +1544,6 @@ struct FullyConnectedLayerInfo bool transpose_weights{ true }; /**< Transpose weights if true. */ bool are_weights_reshaped{ false }; /**< Reshape the weights tensor if false. */ bool retain_internal_weights{ false }; /**< Retain internal reshaped weights. */ - bool constant_weights{ true }; /**< If false, weights can vary between runs. */ /* Other parameters */ bool fp_mixed_precision{ false }; /**< Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy. */ @@ -1951,9 +1950,8 @@ public: _fast_math(false), _fp_mixed_precision(false), _broadcast_bias(false), - _pretranpose_B(true), - _activation_info(), - _constant_weights(true) + _pretranspose_B(true), + _activation_info() { } /** Constructor @@ -1971,11 +1969,10 @@ public: * @param[in] fast_math (Optional) Use a data type of shorter width to improve performance * @param[in] broadcast_bias (Optional) Broadcast the shape of the bias tensor from a vector to a matrix. 
* @param[in] activation_info (Optional) Activation to apply after the matrix multiplication - * @param[in] constant_weights (Optional) Weights have constant values throughout multiple executions */ GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false, GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool fast_math = false, bool broadcast_bias = false, - const ActivationLayerInfo &activation_info = ActivationLayerInfo(), bool constant_weights = true) noexcept + const ActivationLayerInfo &activation_info = ActivationLayerInfo()) noexcept : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), @@ -1986,9 +1983,8 @@ public: _fast_math(fast_math), _fp_mixed_precision(fp_mixed_precision), _broadcast_bias(broadcast_bias), - _pretranpose_B(reshape_b_only_on_first_run), - _activation_info(activation_info), - _constant_weights(constant_weights) + _pretranspose_B(reshape_b_only_on_first_run), + _activation_info(activation_info) { } /** Flag which specifies if the matrix A has been reshaped @@ -2085,17 +2081,17 @@ public: * * @return True if b should be pre-transposed else false. */ - bool pretranpose_B() const + bool pretranspose_B() const { - return _pretranpose_B; + return _pretranspose_B; }; /** Set pre-transpose b flag * * @param[in] flag Flag to set */ - void set_pretranpose_B(bool flag) + void set_pretranspose_B(bool flag) { - _pretranpose_B = flag; + _pretranspose_B = flag; } /** Activation layer to apply after the matrix multiplication * @@ -2113,14 +2109,6 @@ public: { _activation_info = activation_info; } - /** Flag which specifies if the values of the weights tensor are constant throughout multiple executions or not - * - * @return True if the weights tensor is constant - */ - bool constant_weights() const - { - return _constant_weights; - }; private: bool _is_a_reshaped; @@ -2133,9 +2121,8 @@ private: bool _fast_math; bool _fp_mixed_precision; bool _broadcast_bias; - bool _pretranpose_B; + bool _pretranspose_B; ActivationLayerInfo _activation_info; - bool _constant_weights; }; /** Winograd information */ diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp index 5cbdf20798..20c8230148 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp @@ -523,7 +523,7 @@ public: return size; } - void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { if (std::is_same::value) { _col_bias = reinterpret_cast(in_buffer); @@ -534,6 +534,10 @@ public: compute_col_sums(*qp_ptr, _args._Nsize, _args._Ksize * _args._Ksections, B + (i * B_multi_stride), ldb, _col_bias + (i * _args._Nsize), _args._Ksize * _args._Ksections, i, 0); } } + } + + void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + requantize_bias(in_buffer, B, ldb, B_multi_stride); // Put the transposed data after the column sums - in non-transposing cases get_col_sum_size() == 0 uintptr_t buffer_int = reinterpret_cast(in_buffer); diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp 
b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp index c72dca2e96..efb5bd1bb4 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp @@ -269,12 +269,16 @@ public: return get_col_sum_size() + (roundup(_Nsize, strategy::out_width()) * roundup(_Ksize, strategy::k_unroll()) * _nmulti * sizeof(Toi)); } - void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { col_bias = reinterpret_cast(in_buffer); for (unsigned int i=0; i<_nmulti; i++) { compute_col_sums(_qp, _Nsize, _Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize), _Ksize, i, 0); } + } + + void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + requantize_bias(in_buffer, B, ldb, B_multi_stride); uintptr_t buffer_int = reinterpret_cast(in_buffer); Toi *buffer = reinterpret_cast(buffer_int + get_col_sum_size()); diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized_inline.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized_inline.hpp index 7376b5ffe3..e84b58dd0f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized_inline.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized_inline.hpp @@ -219,12 +219,16 @@ public: return get_col_sum_size() + (roundup(_Nsize, strategy::out_width()) * roundup(_Ksize, strategy::k_unroll()) * _nmulti * sizeof(Toi)); } - void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { col_bias = reinterpret_cast(in_buffer); for (unsigned int i=0; i<_nmulti; i++) { compute_col_sums(_qp, _Nsize, _Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize), _Ksize, i, 0); } + } + + void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + requantize_bias(in_buffer, B, ldb, B_multi_stride); uintptr_t buffer_int = reinterpret_cast(in_buffer); Toi *buffer = reinterpret_cast(buffer_int + get_col_sum_size()); diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index 5639cb4182..c75c320a6b 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -923,7 +923,7 @@ public: return (x_size * _Ktotal * _nmulti * sizeof(Toi)) + get_col_sum_size(); } - void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { if (std::is_same::value) { col_bias = reinterpret_cast(in_buffer); @@ -934,6 +934,10 @@ public: compute_col_sums(*qp_ptr, _Nsize, _Ksize * _Ksections, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize), _Ksize * _Ksections, i, 0); } } + } + + void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + requantize_bias(in_buffer, B, ldb, B_multi_stride); // Put the transposed data after the column sums - in non-transposing cases get_col_sum_size() == 0 uintptr_t buffer_int = reinterpret_cast(in_buffer); diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp index 
d4348beabf..f0b4e5db9e 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp @@ -201,11 +201,11 @@ public: return _buffer_per_multi * _args._nmulti * sizeof(To) + get_col_sum_size(); } - void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override { + void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { // Column sums go on the front of the pretransposed buffer in requantized cases. // We could optimize here in case we don't actually need to sum the columns, but this code is only run on setup. if (std::is_same::value) { - col_bias = reinterpret_cast(buffer); + col_bias = reinterpret_cast(in_buffer); Requantize32 *qp_ptr = reinterpret_cast(&_os); @@ -213,6 +213,10 @@ public: compute_col_sums(*qp_ptr, _args._Nsize, _args._Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _args._Nsize), _args._Ksize, i, 0); } } + } + + void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override { + requantize_bias(buffer, B, ldb, B_multi_stride); // The actual transposed buffer goes after the column sums (if any) uintptr_t buffer_int = reinterpret_cast(buffer); diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp index 1e2a9acc1d..ce727032e6 100644 --- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp +++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp @@ -179,13 +179,16 @@ public: return _subgemm->get_B_pretransposed_array_size() + col_sum_size(); } + void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + _col_sums = reinterpret_cast(in_buffer); + col_sums_pretransposed(B, ldb, B_multi_stride); + } + void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override { uintptr_t buffer_int = reinterpret_cast(buffer); _subgemm->pretranspose_B_array(reinterpret_cast(buffer_int + col_sum_size()), B, ldb, B_multi_stride); - _col_sums = reinterpret_cast(buffer); - - col_sums_pretransposed(B, ldb, B_multi_stride); + requantize_bias(buffer, B, ldb, B_multi_stride); } void set_pretransposed_B_data(void *buffer) override { diff --git a/src/core/TensorInfo.cpp b/src/core/TensorInfo.cpp index c471615ee8..e441ddb3a2 100644 --- a/src/core/TensorInfo.cpp +++ b/src/core/TensorInfo.cpp @@ -31,11 +31,11 @@ #include -using namespace arm_compute; - +namespace arm_compute +{ TensorInfo::TensorInfo() : _total_size(0), _offset_first_element_in_bytes(0), _strides_in_bytes(), _num_channels(0), _tensor_shape(), _dims_state(), _data_type(DataType::UNKNOWN), _format(Format::UNKNOWN), _is_resizable{ true }, - _valid_region{ Coordinates(), _tensor_shape }, _padding{ 0 }, _quantization_info(), _data_layout(DataLayout::NCHW) + _valid_region{ Coordinates(), _tensor_shape }, _padding{ 0 }, _quantization_info(), _data_layout(DataLayout::NCHW), _are_values_constant(true) { } @@ -55,6 +55,7 @@ TensorInfo::TensorInfo(const ITensorInfo &info) _padding = info.padding(); _quantization_info = info.quantization_info(); _data_layout = info.data_layout(); + _are_values_constant = info.are_values_constant(); } TensorInfo::TensorInfo(Format format) @@ -377,3 +378,4 @@ int32_t TensorInfo::offset_element_in_bytes(const Coordinates &pos) const return offset; } +} // namespace arm_compute diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp index 
378f1041be..ece9ca5802 100644 --- a/src/cpu/kernels/assembly/gemm_common.hpp +++ b/src/cpu/kernels/assembly/gemm_common.hpp @@ -212,6 +212,9 @@ public: /*** "Pretransposed" interface ***/ + /* Compute col sums over all columns */ + virtual void requantize_bias(void *, const To *, const int, const int) {}; + /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */ /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */ virtual void pretranspose_B_array(void *, const To *, const int, const int) {}; diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp index cafb3484b6..d952724cdc 100644 --- a/src/cpu/operators/CpuFullyConnected.cpp +++ b/src/cpu/operators/CpuFullyConnected.cpp @@ -312,9 +312,14 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei if(_aux_mem[Pretranspose].size > 0) { - // Release permuted weights at the of prepare as they are further transposed by the assembly dispatch - _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), MemoryLifetime::Prepare, _reshaped_weights.total_size()); - _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size()); + // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch + // Do not release them if biases are dynamic and data type is quantized, since the weights tensor will be used for biases offset calculation + _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric + && !(biases->are_values_constant())) ? + MemoryLifetime::Persistent : + MemoryLifetime::Prepare, + _reshaped_weights.total_size()); + _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size()); } else { @@ -332,10 +337,9 @@ Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *we ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights, dst); ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2); - ARM_COMPUTE_RETURN_ERROR_ON(biases != nullptr && biases->num_dimensions() > 1); ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!fc_info.constant_weights, "Non-constant weights are currently not supported"); + ARM_COMPUTE_RETURN_ERROR_ON(!weights->are_values_constant() && (!fc_info.are_weights_reshaped || fc_info.transpose_weights)); bool weights_reshaped = fc_info.transpose_weights ? 
fc_info.are_weights_reshaped : true; bool is_fc_after_conv = true; @@ -356,6 +360,19 @@ Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *we // Check if we have a fully connected layer with batches const bool is_batched_fc_layer = dst->dimension(1) > 1; + if(biases != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); + if(is_data_type_quantized(src->data_type())) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases); + } + } + if(is_batched_fc_layer) { is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index 97893b0672..1dd6286dbf 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -206,6 +206,7 @@ private: std::vector _indirect_pad{}; arm_gemm::ConvolutionParameters _cp{}; experimental::MemoryRequirements _aux_mem{ Count }; + bool _B_pretranspose_required{ false }; }; template @@ -391,6 +392,7 @@ void Fallback::configure(const ITensorInfo * const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size(); _pretranspose_info = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8); _aux_mem[Pretranspose] = MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment); + _B_pretranspose_required = true; } // Handle indirect GEMM convolution @@ -485,6 +487,35 @@ void Fallback::run(ITensorPack &tensors) in1_ptr = reinterpret_cast(b->buffer() + b->info()->offset_first_element_in_bytes()); } + // If necessary, run pretranspose every time if either weights or biases are non-constant + if((b && !b->info()->are_values_constant()) || (c && !c->info()->are_values_constant() && c->info()->data_type() == DataType::S32)) + { + if(c && c->info()->data_type() == DataType::S32) + { + _gemm_kernel_asm->set_quantized_bias(reinterpret_cast(c->buffer() + c->info()->offset_first_element_in_bytes()), 0); + } + + // Pretranspose B if required + if(_B_pretranspose_required) + { + const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput); + const auto b_ptr = reinterpret_cast(b->buffer() + b->info()->offset_first_element_in_bytes()); + const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput); + + CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true); + ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr); + + if(b->info()->are_values_constant()) + { + _gemm_kernel_asm->requantize_bias(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b); + } + else + { + _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b); + } + } + } + const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type()); // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads diff --git a/src/gpu/cl/operators/ClFullyConnected.cpp b/src/gpu/cl/operators/ClFullyConnected.cpp index 8b7e336c9f..bd2fddad0b 100644 --- a/src/gpu/cl/operators/ClFullyConnected.cpp +++ b/src/gpu/cl/operators/ClFullyConnected.cpp @@ -169,8 +169,7 @@ void ClFullyConnected::configure_mm(const CLCompileContext &compile_context, ITe fc_info.fp_mixed_precision, // 
fp_mixed_precision false, // fast_math true, // broadcast_bias - fc_info.activation_info, // activation_info - fc_info.constant_weights); // constant_weights + fc_info.activation_info); // activation_info if(_is_quantized) { @@ -333,7 +332,7 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2); ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU); - ARM_COMPUTE_RETURN_ERROR_ON(!fc_info.constant_weights && (!fc_info.are_weights_reshaped || fc_info.transpose_weights)); + ARM_COMPUTE_RETURN_ERROR_ON(!weights->are_values_constant() && (!fc_info.are_weights_reshaped || fc_info.transpose_weights)); bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true; bool is_fc_after_conv = true; @@ -351,6 +350,19 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei const ITensorInfo *src_to_use = src; const ITensorInfo *weights_to_use = weights; + if(biases != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); + if(is_data_type_quantized(src->data_type())) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases); + } + } + // Check if we have a fully connected layer with batches const bool is_batched_fc_layer = dst->dimension(1) > 1; if(is_batched_fc_layer) diff --git a/src/gpu/cl/operators/ClGemm.cpp b/src/gpu/cl/operators/ClGemm.cpp index 625c057cf4..292f531dc4 100644 --- a/src/gpu/cl/operators/ClGemm.cpp +++ b/src/gpu/cl/operators/ClGemm.cpp @@ -574,7 +574,7 @@ void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, // Select GEMMType _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run, - gemm_info.constant_weights()); + b->are_values_constant()); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); @@ -623,7 +623,7 @@ Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso { CLScheduler::get().target(), a->data_type(), m, n, k, batch_size, }, - gemm_info.reshape_b_only_on_first_run(), gemm_info.constant_weights()); + gemm_info.reshape_b_only_on_first_run(), b->are_values_constant()); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); diff --git a/tests/validation/NEON/FullyConnectedLayer.cpp b/tests/validation/NEON/FullyConnectedLayer.cpp index 413250f755..5639fb47da 100644 --- a/tests/validation/NEON/FullyConnectedLayer.cpp +++ b/tests/validation/NEON/FullyConnectedLayer.cpp @@ -290,6 +290,10 @@ template using NEFullyConnectedLayerFixture = FullyConnectedLayerValidationFixture; template using NEFullyConnectedLayerMixedDataLayoutFixture = FullyConnectedLayerValidationFixture; +template +using NEFullyConnectedLayerDynamicWeightsFixture = FullyConnectedWithDynamicWeightsFixture; +template +using NEFullyConnectedLayerDynamicBiasFixture = FullyConnectedWithDynamicBiasFixture; TEST_SUITE(Float) #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -358,6 +362,11 @@ 
FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerFixture, framework: // Validate output validate(Accessor(_target), _reference, rel_tolerance_f32, 0, abs_tolerance_f32); } +FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallFullyConnectedLayerDataset(), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)))) +{ +} TEST_SUITE_END() TEST_SUITE_END() @@ -413,6 +422,12 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerQuantizedFixture, // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); } + +FIXTURE_DATA_TEST_CASE(RunDynamicBias, NEFullyConnectedLayerDynamicBiasFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallFullyConnectedLayerDataset(), + framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)))) +{ +} TEST_SUITE_END() TEST_SUITE(QASYMM8_SIGNED) FIXTURE_DATA_TEST_CASE(RunSmall, NEFullyConnectedLayerQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine( diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h index 7d767642f3..ccd9182ae9 100644 --- a/tests/validation/fixtures/FullyConnectedLayerFixture.h +++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h @@ -273,7 +273,7 @@ public: }; template -class FullyConnectedWithDynamicWeightsFixture : public framework::Fixture +class FullyConnectedWithDynamicTensorsFixture : public framework::Fixture { private: template @@ -289,6 +289,16 @@ private: std::uniform_real_distribution distribution(-1.0f, 1.0f); library->fill(tensor, distribution, i); } + else if(_data_type == DataType::QASYMM8) + { + std::uniform_int_distribution distribution(0, 30); + library->fill(tensor, distribution, i); + } + else if(_data_type == DataType::S32) + { + std::uniform_int_distribution distribution(-50, 50); + library->fill(tensor, distribution, i); + } else { library->fill_tensor_uniform(tensor, i); @@ -324,6 +334,11 @@ private: constexpr AbsoluteTolerance abs_tolerance_f32(0.0001f); validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32); } + else if(_data_type == DataType::QASYMM8) + { + constexpr AbsoluteTolerance tolerance_qasymm8(1); + validate(AccessorType(target), ref, tolerance_qasymm8); + } else { validate(AccessorType(target), ref); @@ -331,32 +346,51 @@ private: } public: + using TDecay = typename std::decay::type; + using TBias = typename std::conditional < (std::is_same::value || std::is_same::value), int32_t, T >::type; + template void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape, - DataType data_type, ActivationLayerInfo activation_info) + DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias) { _data_type = data_type; + const bool is_quantized = is_data_type_quantized(data_type); + + const DataType bias_data_type = (is_quantized) ? DataType::S32 : data_type; + + const QuantizationInfo src_qinfo = is_quantized ? QuantizationInfo(0.1f, 10) : QuantizationInfo(); + const QuantizationInfo weights_qinfo = is_quantized ? QuantizationInfo(0.3f, 20) : QuantizationInfo(); + const QuantizationInfo dst_qinfo = is_quantized ? 
QuantizationInfo(0.2f, 5) : QuantizationInfo(); + // Setup tensor meta-data - TensorInfo src_info(src_shape, 1, data_type); + const TensorInfo src_info(src_shape, 1, data_type, src_qinfo); _src.allocator()->init(src_info); - TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] }; - TensorInfo wei_info(tr_weights_shape, 1, data_type); + TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo); + if(!constant_weights) + { + const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] }; + wei_info.set_tensor_shape(tr_weights_shape); + } + wei_info.set_are_values_constant(constant_weights); _weights.allocator()->init(wei_info); - TensorInfo bias_info(bias_shape, 1, data_type); + TensorInfo bias_info(bias_shape, 1, bias_data_type); + bias_info.set_are_values_constant(constant_bias); _bias.allocator()->init(bias_info); - TensorInfo dst_info(dst_shape, 1, data_type); + const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo); _dst.allocator()->init(dst_info); // Configure FC layer and mark the weights as non constant FullyConnectedLayerInfo fc_info; - fc_info.activation_info = activation_info; - fc_info.are_weights_reshaped = true; - fc_info.transpose_weights = false; - fc_info.constant_weights = false; + fc_info.activation_info = activation_info; + if(!constant_weights) + { + fc_info.are_weights_reshaped = true; + fc_info.transpose_weights = false; + } FunctionType fc; fc.configure(&_src, &_weights, &_bias, &_dst, fc_info); @@ -369,29 +403,55 @@ public: // Run multiple iterations with different inputs constexpr int num_iterations = 5; int randomizer_offset = 0; + + // Create reference tensors + SimpleTensor src{ src_shape, data_type, 1, src_qinfo }; + SimpleTensor weights{ weights_shape, data_type, 1, weights_qinfo }; + SimpleTensor bias{ bias_shape, bias_data_type }; + + // Fill weights and/or bias if they remain constant + if(constant_weights) + { + fill(AccessorType(_weights), 1); + fill(weights, 1); + } + if(constant_bias) + { + fill(AccessorType(_bias), 2); + fill(bias, 2); + } + for(int i = 0; i < num_iterations; ++i) { // Run target { fill(AccessorType(_src), randomizer_offset); - fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1); - fill(AccessorType(_bias), randomizer_offset + 2); + if(!constant_weights) + { + fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1); + } + if(!constant_bias) + { + fill(AccessorType(_bias), randomizer_offset + 2); + } fc.run(); } // Run reference and compare { - SimpleTensor src{ src_shape, data_type }; - SimpleTensor weights{ weights_shape, data_type }; - SimpleTensor bias{ bias_shape, data_type }; - // Fill reference fill(src, randomizer_offset); - fill(weights, randomizer_offset + 1); - fill(bias, randomizer_offset + 2); + if(!constant_weights) + { + fill(weights, randomizer_offset + 1); + } + if(!constant_bias) + { + fill(bias, randomizer_offset + 2); + } - auto dst = reference::activation_layer(reference::fully_connected_layer(src, weights, bias, dst_shape), activation_info); + auto dst = reference::activation_layer(reference::fully_connected_layer(src, weights, bias, dst_shape), activation_info, dst_qinfo); // Validate validate_with_tolerance(_dst, dst); @@ -405,6 +465,32 @@ private: TensorType _src{}, _weights{}, _bias{}, _dst{}; DataType _data_type{ DataType::UNKNOWN }; }; + +template +class FullyConnectedWithDynamicWeightsFixture : public FullyConnectedWithDynamicTensorsFixture +{ +public: + template + void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape 
bias_shape, TensorShape dst_shape, + DataType data_type, ActivationLayerInfo activation_info) + { + FullyConnectedWithDynamicTensorsFixture::setup(src_shape, weights_shape, bias_shape, + dst_shape, data_type, activation_info, false, true); + } +}; + +template +class FullyConnectedWithDynamicBiasFixture : public FullyConnectedWithDynamicTensorsFixture +{ +public: + template + void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape, + DataType data_type, ActivationLayerInfo activation_info) + { + FullyConnectedWithDynamicTensorsFixture::setup(src_shape, weights_shape, bias_shape, + dst_shape, data_type, activation_info, true, false); + } +}; } // namespace validation } // namespace test } // namespace arm_compute diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h index 58ddb3f7bf..248c973b68 100644 --- a/utils/TypePrinter.h +++ b/utils/TypePrinter.h @@ -1158,7 +1158,7 @@ inline ::std::ostream &operator<<(::std::ostream &os, const GEMMInfo &info) os << "retain_internal_weights=" << info.retain_internal_weights() << ","; os << "fp_mixed_precision=" << info.fp_mixed_precision() << ","; os << "broadcast_bias=" << info.broadcast_bias() << ","; - os << "pretranpose_B=" << info.pretranpose_B() << ","; + os << "pretranspose_B=" << info.pretranspose_B() << ","; return os; } -- cgit v1.2.1
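To tie the pieces together, below is a rough end-to-end sketch of driving the NEON fully connected function with dynamic weights after this change. It mirrors the constraint enforced in CpuFullyConnected::validate() above (non-constant weights must be supplied pre-reshaped, i.e. are_weights_reshaped = true and transpose_weights = false); the shapes, the helper name and the per-iteration fill step are hypothetical.

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void fully_connected_with_dynamic_weights()
{
    // Illustrative sizes: 64 inputs, 16 outputs, batch of 1
    Tensor src, weights, bias, dst;
    src.allocator()->init(TensorInfo(TensorShape(64U, 1U), 1, DataType::F32));
    dst.allocator()->init(TensorInfo(TensorShape(16U, 1U), 1, DataType::F32));

    // Weights arrive already in their final (reshaped) layout and will be updated
    // between runs, so their values are flagged as non-constant.
    TensorInfo wei_info(TensorShape(16U, 64U), 1, DataType::F32);
    wei_info.set_are_values_constant(false);
    weights.allocator()->init(wei_info);

    bias.allocator()->init(TensorInfo(TensorShape(16U), 1, DataType::F32));

    FullyConnectedLayerInfo fc_info;
    fc_info.activation_info      = ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU);
    fc_info.are_weights_reshaped = true;  // required when weights are non-constant
    fc_info.transpose_weights    = false; // required when weights are non-constant

    NEFullyConnectedLayer fc;
    fc.configure(&src, &weights, &bias, &dst, fc_info);

    src.allocator()->allocate();
    weights.allocator()->allocate();
    bias.allocator()->allocate();
    dst.allocator()->allocate();

    for(int i = 0; i < 3; ++i)
    {
        // ... refresh the src and weights buffers here ...
        fc.run(); // non-constant weights are re-processed on every run
    }
}

With this setup the CPU assembly dispatch redoes the B pretranspose (or only the bias requantization when just a quantized bias changes) on every run, which is the code path added to CpuGemmAssemblyDispatch above.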