aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2021-09-24 14:04:27 +0100
committerGiorgio Arena <giorgio.arena@arm.com>2021-09-29 10:31:08 +0000
commit63e0beb9fb9646407d123e830165546e9129e95d (patch)
tree9bfe80e8d853327a82f9f622d89c3b43df0400f4
parentb1ba1e33f2b03b211f561123559c24517c0e5865 (diff)
downloadComputeLibrary-63e0beb9fb9646407d123e830165546e9129e95d.tar.gz
Add support for non-constant weights and biases in CpuFullyConnected
Changing the approach for specifying that weights and biases tensors are non-constant by making it a member of TensorInfo rather than an option of the functions. Resolves: COMPMID-4222, COMPMID-4811 Signed-off-by: Giorgio Arena <giorgio.arena@arm.com> Change-Id: I9b0081ccbcf8271ce029ba6755563d64c59e1d32 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6313 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/core/ITensorInfo.h12
-rw-r--r--arm_compute/core/SubTensorInfo.h11
-rw-r--r--arm_compute/core/TensorInfo.h10
-rw-r--r--arm_compute/core/Types.h19
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized_inline.hpp8
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp8
-rw-r--r--src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp9
-rw-r--r--src/core/TensorInfo.cpp8
-rw-r--r--src/cpu/kernels/assembly/gemm_common.hpp3
-rw-r--r--src/cpu/operators/CpuFullyConnected.cpp27
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp37
-rw-r--r--src/gpu/cl/operators/ClFullyConnected.cpp18
-rw-r--r--src/gpu/cl/operators/ClGemm.cpp4
-rw-r--r--tests/validation/NEON/FullyConnectedLayer.cpp15
-rw-r--r--tests/validation/fixtures/FullyConnectedLayerFixture.h128
18 files changed, 275 insertions, 60 deletions
diff --git a/arm_compute/core/ITensorInfo.h b/arm_compute/core/ITensorInfo.h
index 1f5cf30148..6839d697e3 100644
--- a/arm_compute/core/ITensorInfo.h
+++ b/arm_compute/core/ITensorInfo.h
@@ -240,6 +240,11 @@ public:
* @return True if its dynamic else false
*/
virtual bool is_dynamic() const = 0;
+ /** Flag indicating whether the values of the tensor are constant, meaning that they can change on kernel/function execution.
+ *
+ * @return True if values are constant else false
+ */
+ virtual bool are_values_constant() const = 0;
/** Set the flag whether the tensor size can be changed.
*
* @param[in] is_resizable Flag that marks the tensor if it can be changed or not.
@@ -247,6 +252,13 @@ public:
* @return Reference to this ITensorInfo object
*/
virtual ITensorInfo &set_is_resizable(bool is_resizable) = 0;
+ /** Set the flag whether the tensor values can change during kernel/function execution.
+ *
+ * @param[in] are_values_constant Flag that marks the tensor values if they can be changed or not.
+ *
+ * @return Reference to this ITensorInfo object
+ */
+ virtual ITensorInfo &set_are_values_constant(bool are_values_constant) = 0;
/** Valid region of the tensor. All elements in the valid region have defined values, i.e. are not undefined.
*
* @return The valid region.
diff --git a/arm_compute/core/SubTensorInfo.h b/arm_compute/core/SubTensorInfo.h
index 1b2278d99b..54836d0528 100644
--- a/arm_compute/core/SubTensorInfo.h
+++ b/arm_compute/core/SubTensorInfo.h
@@ -196,12 +196,23 @@ public:
ARM_COMPUTE_ERROR_ON(_parent == nullptr);
return _parent->is_dynamic();
}
+ bool are_values_constant() const override
+ {
+ ARM_COMPUTE_ERROR_ON(_parent == nullptr);
+ return _parent->are_values_constant();
+ }
ITensorInfo &set_is_resizable(bool is_resizable) override
{
ARM_COMPUTE_ERROR_ON(_parent == nullptr);
_parent->set_is_resizable(is_resizable);
return *this;
}
+ ITensorInfo &set_are_values_constant(bool are_values_constant) override
+ {
+ ARM_COMPUTE_ERROR_ON(_parent == nullptr);
+ _parent->set_are_values_constant(are_values_constant);
+ return *this;
+ }
ValidRegion valid_region() const override
{
return _valid_region;
diff --git a/arm_compute/core/TensorInfo.h b/arm_compute/core/TensorInfo.h
index a4330849bf..9bc86806fb 100644
--- a/arm_compute/core/TensorInfo.h
+++ b/arm_compute/core/TensorInfo.h
@@ -267,6 +267,10 @@ public:
{
return std::find(std::cbegin(_dims_state), std::cend(_dims_state), get_dynamic_state_value()) != std::cend(_dims_state);
}
+ bool are_values_constant() const override
+ {
+ return _are_values_constant;
+ }
ITensorInfo &set_is_resizable(bool is_resizable) override
{
_is_resizable = is_resizable;
@@ -288,6 +292,11 @@ public:
{
return _data_layout;
}
+ ITensorInfo &set_are_values_constant(bool are_values_constant) override
+ {
+ _are_values_constant = are_values_constant;
+ return *this;
+ }
private:
/** Calculates strides, offset and total size resulting from the specified padding around the XY plane.
@@ -309,6 +318,7 @@ private:
PaddingSize _padding;
QuantizationInfo _quantization_info;
DataLayout _data_layout;
+ bool _are_values_constant;
};
} // namespace arm_compute
#endif /*ARM_COMPUTE_TENSORINFO_H */
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 0acbb3f59e..31199e138b 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1557,7 +1557,6 @@ struct FullyConnectedLayerInfo
bool transpose_weights{ true }; /**< Transpose weights if true. */
bool are_weights_reshaped{ false }; /**< Reshape the weights tensor if false. */
bool retain_internal_weights{ false }; /**< Retain internal reshaped weights. */
- bool constant_weights{ true }; /**< If false, weights can vary between runs. */
/* Other parameters */
bool fp_mixed_precision{ false }; /**< Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy. */
@@ -1965,8 +1964,7 @@ public:
_fp_mixed_precision(false),
_broadcast_bias(false),
_pretranspose_B(true),
- _activation_info(),
- _constant_weights(true)
+ _activation_info()
{
}
/** Constructor
@@ -1984,11 +1982,10 @@ public:
* @param[in] fast_math (Optional) Use a data type of shorter width to improve performance
* @param[in] broadcast_bias (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
* @param[in] activation_info (Optional) Activation to apply after the matrix multiplication
- * @param[in] constant_weights (Optional) Weights have constant values throughout multiple executions
*/
GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool fast_math = false, bool broadcast_bias = false,
- const ActivationLayerInfo &activation_info = ActivationLayerInfo(), bool constant_weights = true) noexcept
+ const ActivationLayerInfo &activation_info = ActivationLayerInfo()) noexcept
: _is_a_reshaped(is_a_reshaped),
_is_b_reshaped(is_b_reshaped),
_reshape_b_only_on_first_run(reshape_b_only_on_first_run),
@@ -2000,8 +1997,7 @@ public:
_fp_mixed_precision(fp_mixed_precision),
_broadcast_bias(broadcast_bias),
_pretranspose_B(reshape_b_only_on_first_run),
- _activation_info(activation_info),
- _constant_weights(constant_weights)
+ _activation_info(activation_info)
{
}
/** Flag which specifies if the matrix A has been reshaped
@@ -2126,14 +2122,6 @@ public:
{
_activation_info = activation_info;
}
- /** Flag which specifies if the values of the weights tensor are constant throughout multiple executions or not
- *
- * @return True if the weights tensor is constant
- */
- bool constant_weights() const
- {
- return _constant_weights;
- };
private:
bool _is_a_reshaped;
@@ -2148,7 +2136,6 @@ private:
bool _broadcast_bias;
bool _pretranspose_B;
ActivationLayerInfo _activation_info;
- bool _constant_weights;
};
/** Winograd information */
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
index 5cbdf20798..20c8230148 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
@@ -523,7 +523,7 @@ public:
return size;
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
if (std::is_same<OutputStage, Requantize32>::value) {
_col_bias = reinterpret_cast<int32_t *>(in_buffer);
@@ -534,6 +534,10 @@ public:
compute_col_sums(*qp_ptr, _args._Nsize, _args._Ksize * _args._Ksections, B + (i * B_multi_stride), ldb, _col_bias + (i * _args._Nsize), _args._Ksize * _args._Ksections, i, 0);
}
}
+ }
+
+ void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ requantize_bias(in_buffer, B, ldb, B_multi_stride);
// Put the transposed data after the column sums - in non-transposing cases get_col_sum_size() == 0
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
index c72dca2e96..efb5bd1bb4 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
@@ -269,12 +269,16 @@ public:
return get_col_sum_size() + (roundup(_Nsize, strategy::out_width()) * roundup(_Ksize, strategy::k_unroll()) * _nmulti * sizeof(Toi));
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
col_bias = reinterpret_cast<int32_t *>(in_buffer);
for (unsigned int i=0; i<_nmulti; i++) {
compute_col_sums(_qp, _Nsize, _Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize), _Ksize, i, 0);
}
+ }
+
+ void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ requantize_bias(in_buffer, B, ldb, B_multi_stride);
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
Toi *buffer = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size());
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized_inline.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized_inline.hpp
index 7376b5ffe3..820b54202a 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized_inline.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized_inline.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 Arm Limited.
+ * Copyright (c) 2017-2019,2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -219,12 +219,16 @@ public:
return get_col_sum_size() + (roundup(_Nsize, strategy::out_width()) * roundup(_Ksize, strategy::k_unroll()) * _nmulti * sizeof(Toi));
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
col_bias = reinterpret_cast<int32_t *>(in_buffer);
for (unsigned int i=0; i<_nmulti; i++) {
compute_col_sums(_qp, _Nsize, _Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize), _Ksize, i, 0);
}
+ }
+
+ void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ requantize_bias(in_buffer, B, ldb, B_multi_stride);
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
Toi *buffer = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size());
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index 5639cb4182..c75c320a6b 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -923,7 +923,7 @@ public:
return (x_size * _Ktotal * _nmulti * sizeof(Toi)) + get_col_sum_size();
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
if (std::is_same<OutputStage, Requantize32>::value) {
col_bias = reinterpret_cast<int32_t *>(in_buffer);
@@ -934,6 +934,10 @@ public:
compute_col_sums(*qp_ptr, _Nsize, _Ksize * _Ksections, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize), _Ksize * _Ksections, i, 0);
}
}
+ }
+
+ void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ requantize_bias(in_buffer, B, ldb, B_multi_stride);
// Put the transposed data after the column sums - in non-transposing cases get_col_sum_size() == 0
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index d4348beabf..f0b4e5db9e 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -201,11 +201,11 @@ public:
return _buffer_per_multi * _args._nmulti * sizeof(To) + get_col_sum_size();
}
- void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
// Column sums go on the front of the pretransposed buffer in requantized cases.
// We could optimize here in case we don't actually need to sum the columns, but this code is only run on setup.
if (std::is_same<OutputStage, Requantize32>::value) {
- col_bias = reinterpret_cast<int32_t *>(buffer);
+ col_bias = reinterpret_cast<int32_t *>(in_buffer);
Requantize32 *qp_ptr = reinterpret_cast<Requantize32 *>(&_os);
@@ -213,6 +213,10 @@ public:
compute_col_sums(*qp_ptr, _args._Nsize, _args._Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _args._Nsize), _args._Ksize, i, 0);
}
}
+ }
+
+ void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ requantize_bias(buffer, B, ldb, B_multi_stride);
// The actual transposed buffer goes after the column sums (if any)
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer);
diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
index 1e2a9acc1d..ce727032e6 100644
--- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
+++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
@@ -179,13 +179,16 @@ public:
return _subgemm->get_B_pretransposed_array_size() + col_sum_size();
}
+ void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ _col_sums = reinterpret_cast<int32_t *>(in_buffer);
+ col_sums_pretransposed(B, ldb, B_multi_stride);
+ }
+
void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer);
_subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride);
- _col_sums = reinterpret_cast<int32_t *>(buffer);
-
- col_sums_pretransposed(B, ldb, B_multi_stride);
+ requantize_bias(buffer, B, ldb, B_multi_stride);
}
void set_pretransposed_B_data(void *buffer) override {
diff --git a/src/core/TensorInfo.cpp b/src/core/TensorInfo.cpp
index c471615ee8..e441ddb3a2 100644
--- a/src/core/TensorInfo.cpp
+++ b/src/core/TensorInfo.cpp
@@ -31,11 +31,11 @@
#include <memory>
-using namespace arm_compute;
-
+namespace arm_compute
+{
TensorInfo::TensorInfo()
: _total_size(0), _offset_first_element_in_bytes(0), _strides_in_bytes(), _num_channels(0), _tensor_shape(), _dims_state(), _data_type(DataType::UNKNOWN), _format(Format::UNKNOWN), _is_resizable{ true },
- _valid_region{ Coordinates(), _tensor_shape }, _padding{ 0 }, _quantization_info(), _data_layout(DataLayout::NCHW)
+ _valid_region{ Coordinates(), _tensor_shape }, _padding{ 0 }, _quantization_info(), _data_layout(DataLayout::NCHW), _are_values_constant(true)
{
}
@@ -55,6 +55,7 @@ TensorInfo::TensorInfo(const ITensorInfo &info)
_padding = info.padding();
_quantization_info = info.quantization_info();
_data_layout = info.data_layout();
+ _are_values_constant = info.are_values_constant();
}
TensorInfo::TensorInfo(Format format)
@@ -377,3 +378,4 @@ int32_t TensorInfo::offset_element_in_bytes(const Coordinates &pos) const
return offset;
}
+} // namespace arm_compute
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index 378f1041be..ece9ca5802 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -212,6 +212,9 @@ public:
/*** "Pretransposed" interface ***/
+ /* Compute col sums over all columns */
+ virtual void requantize_bias(void *, const To *, const int, const int) {};
+
/* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
/* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
virtual void pretranspose_B_array(void *, const To *, const int, const int) {};
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp
index 4133d9e8ca..03c53b001d 100644
--- a/src/cpu/operators/CpuFullyConnected.cpp
+++ b/src/cpu/operators/CpuFullyConnected.cpp
@@ -314,9 +314,14 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei
if(_aux_mem[Pretranspose].size > 0)
{
- // Release permuted weights at the of prepare as they are further transposed by the assembly dispatch
- _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), MemoryLifetime::Prepare, _reshaped_weights.total_size());
- _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size());
+ // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
+ // Do not release them if biases are dynamic and data type is quantized, since the weights tensor will be used for biases offset calculation
+ _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric
+ && biases && !(biases->are_values_constant())) ?
+ MemoryLifetime::Persistent :
+ MemoryLifetime::Prepare,
+ _reshaped_weights.total_size());
+ _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size());
}
else
{
@@ -334,10 +339,9 @@ Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *we
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights, dst);
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
- ARM_COMPUTE_RETURN_ERROR_ON(biases != nullptr && biases->num_dimensions() > 1);
ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
&& fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!fc_info.constant_weights, "Non-constant weights are currently not supported");
+ ARM_COMPUTE_RETURN_ERROR_ON(!weights->are_values_constant() && (!fc_info.are_weights_reshaped || fc_info.transpose_weights));
bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
bool is_fc_after_conv = true;
@@ -358,6 +362,19 @@ Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *we
// Check if we have a fully connected layer with batches
const bool is_batched_fc_layer = dst->dimension(1) > 1;
+ if(biases != nullptr)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
+ if(is_data_type_quantized(src->data_type()))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases);
+ }
+ }
+
if(is_batched_fc_layer)
{
is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3,
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 97893b0672..23095d8b84 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -206,6 +206,9 @@ private:
std::vector<TypeInput> _indirect_pad{};
arm_gemm::ConvolutionParameters _cp{};
experimental::MemoryRequirements _aux_mem{ Count };
+ bool _B_pretranspose_required{ false };
+ bool _is_b_constant{ true };
+ bool _is_c_constant{ true };
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -348,6 +351,10 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
const OutputStage &os)
{
ARM_COMPUTE_UNUSED(c);
+
+ _is_b_constant = b->are_values_constant();
+ _is_c_constant = c ? c->are_values_constant() : true;
+
arm_gemm::GemmConfig gemm_cfg;
_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
if(_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
@@ -391,6 +398,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
_pretranspose_info = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
_aux_mem[Pretranspose] = MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
+ _B_pretranspose_required = true;
}
// Handle indirect GEMM convolution
@@ -485,6 +493,35 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
}
+ // If necessary, run pretranspose every time if either weights or biases are non-constant
+ if((b && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32))
+ {
+ if(c && c->info()->data_type() == DataType::S32)
+ {
+ _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
+ }
+
+ // Pretranspose B if required
+ if(_B_pretranspose_required)
+ {
+ const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const auto b_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+ const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+
+ CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
+ ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
+
+ if(_is_b_constant)
+ {
+ _gemm_kernel_asm->requantize_bias(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
+ }
+ else
+ {
+ _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
+ }
+ }
+ }
+
const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type());
// Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
diff --git a/src/gpu/cl/operators/ClFullyConnected.cpp b/src/gpu/cl/operators/ClFullyConnected.cpp
index 165ffe9a47..8afd036e7c 100644
--- a/src/gpu/cl/operators/ClFullyConnected.cpp
+++ b/src/gpu/cl/operators/ClFullyConnected.cpp
@@ -170,8 +170,7 @@ void ClFullyConnected::configure_mm(const CLCompileContext &compile_context, ITe
fc_info.fp_mixed_precision, // fp_mixed_precision
false, // fast_math
true, // broadcast_bias
- fc_info.activation_info, // activation_info
- fc_info.constant_weights); // constant_weights
+ fc_info.activation_info); // activation_info
if(_is_quantized)
{
@@ -335,7 +334,7 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
&& fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
- ARM_COMPUTE_RETURN_ERROR_ON(!fc_info.constant_weights && (!fc_info.are_weights_reshaped || fc_info.transpose_weights));
+ ARM_COMPUTE_RETURN_ERROR_ON(!weights->are_values_constant() && (!fc_info.are_weights_reshaped || fc_info.transpose_weights));
bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
bool is_fc_after_conv = true;
@@ -353,6 +352,19 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei
const ITensorInfo *src_to_use = src;
const ITensorInfo *weights_to_use = weights;
+ if(biases != nullptr)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
+ if(is_data_type_quantized(src->data_type()))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases);
+ }
+ }
+
// Check if we have a fully connected layer with batches
const bool is_batched_fc_layer = dst->dimension(1) > 1;
if(is_batched_fc_layer)
diff --git a/src/gpu/cl/operators/ClGemm.cpp b/src/gpu/cl/operators/ClGemm.cpp
index e955ae3d65..4cd5237b11 100644
--- a/src/gpu/cl/operators/ClGemm.cpp
+++ b/src/gpu/cl/operators/ClGemm.cpp
@@ -575,7 +575,7 @@ void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a,
// Select GEMMType
_gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run,
- gemm_info.constant_weights());
+ b->are_values_constant());
const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
@@ -624,7 +624,7 @@ Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
{
CLScheduler::get().target(), a->data_type(), m, n, k, batch_size,
},
- gemm_info.reshape_b_only_on_first_run(), gemm_info.constant_weights());
+ gemm_info.reshape_b_only_on_first_run(), b->are_values_constant());
const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
diff --git a/tests/validation/NEON/FullyConnectedLayer.cpp b/tests/validation/NEON/FullyConnectedLayer.cpp
index 413250f755..5639fb47da 100644
--- a/tests/validation/NEON/FullyConnectedLayer.cpp
+++ b/tests/validation/NEON/FullyConnectedLayer.cpp
@@ -290,6 +290,10 @@ template <typename T>
using NEFullyConnectedLayerFixture = FullyConnectedLayerValidationFixture<Tensor, Accessor, NEFullyConnectedLayer, T>;
template <typename T>
using NEFullyConnectedLayerMixedDataLayoutFixture = FullyConnectedLayerValidationFixture<Tensor, Accessor, NEFullyConnectedLayer, T, true>;
+template <typename T>
+using NEFullyConnectedLayerDynamicWeightsFixture = FullyConnectedWithDynamicWeightsFixture<Tensor, Accessor, NEFullyConnectedLayer, T>;
+template <typename T>
+using NEFullyConnectedLayerDynamicBiasFixture = FullyConnectedWithDynamicBiasFixture<Tensor, Accessor, NEFullyConnectedLayer, T>;
TEST_SUITE(Float)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -358,6 +362,11 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerFixture<float>, framework:
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f32, 0, abs_tolerance_f32);
}
+FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallFullyConnectedLayerDataset(),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))))
+{
+}
TEST_SUITE_END()
TEST_SUITE_END()
@@ -413,6 +422,12 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerQuantizedFixture<uint8_t>,
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
+
+FIXTURE_DATA_TEST_CASE(RunDynamicBias, NEFullyConnectedLayerDynamicBiasFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallFullyConnectedLayerDataset(),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))))
+{
+}
TEST_SUITE_END()
TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall, NEFullyConnectedLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h
index 7d767642f3..3048c56f6b 100644
--- a/tests/validation/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h
@@ -232,7 +232,7 @@ protected:
fill(weights, 1);
fill(bias, 2);
- return reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, output_shape), _activation_info, _quantization_info);
+ return reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, output_shape, _quantization_info), _activation_info, _quantization_info);
}
TensorType _target{};
@@ -273,7 +273,7 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class FullyConnectedWithDynamicWeightsFixture : public framework::Fixture
+class FullyConnectedWithDynamicTensorsFixture : public framework::Fixture
{
private:
template <typename U>
@@ -289,6 +289,16 @@ private:
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
library->fill(tensor, distribution, i);
}
+ else if(_data_type == DataType::QASYMM8)
+ {
+ std::uniform_int_distribution<uint8_t> distribution(0, 30);
+ library->fill(tensor, distribution, i);
+ }
+ else if(_data_type == DataType::S32)
+ {
+ std::uniform_int_distribution<int32_t> distribution(-50, 50);
+ library->fill(tensor, distribution, i);
+ }
else
{
library->fill_tensor_uniform(tensor, i);
@@ -324,6 +334,11 @@ private:
constexpr AbsoluteTolerance<float> abs_tolerance_f32(0.0001f);
validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32);
}
+ else if(_data_type == DataType::QASYMM8)
+ {
+ constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1);
+ validate(AccessorType(target), ref, tolerance_qasymm8);
+ }
else
{
validate(AccessorType(target), ref);
@@ -331,32 +346,51 @@ private:
}
public:
+ using TDecay = typename std::decay<T>::type;
+ using TBias = typename std::conditional < (std::is_same<TDecay, uint8_t>::value || std::is_same<TDecay, int8_t>::value), int32_t, T >::type;
+
template <typename...>
void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
- DataType data_type, ActivationLayerInfo activation_info)
+ DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias)
{
_data_type = data_type;
+ const bool is_quantized = is_data_type_quantized(data_type);
+
+ const DataType bias_data_type = (is_quantized) ? DataType::S32 : data_type;
+
+ const QuantizationInfo src_qinfo = is_quantized ? QuantizationInfo(0.1f, 10) : QuantizationInfo();
+ const QuantizationInfo weights_qinfo = is_quantized ? QuantizationInfo(0.3f, 20) : QuantizationInfo();
+ const QuantizationInfo dst_qinfo = is_quantized ? QuantizationInfo(0.2f, 5) : QuantizationInfo();
+
// Setup tensor meta-data
- TensorInfo src_info(src_shape, 1, data_type);
+ const TensorInfo src_info(src_shape, 1, data_type, src_qinfo);
_src.allocator()->init(src_info);
- TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] };
- TensorInfo wei_info(tr_weights_shape, 1, data_type);
+ TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo);
+ if(!constant_weights)
+ {
+ const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] };
+ wei_info.set_tensor_shape(tr_weights_shape);
+ }
+ wei_info.set_are_values_constant(constant_weights);
_weights.allocator()->init(wei_info);
- TensorInfo bias_info(bias_shape, 1, data_type);
+ TensorInfo bias_info(bias_shape, 1, bias_data_type);
+ bias_info.set_are_values_constant(constant_bias);
_bias.allocator()->init(bias_info);
- TensorInfo dst_info(dst_shape, 1, data_type);
+ const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo);
_dst.allocator()->init(dst_info);
// Configure FC layer and mark the weights as non constant
FullyConnectedLayerInfo fc_info;
- fc_info.activation_info = activation_info;
- fc_info.are_weights_reshaped = true;
- fc_info.transpose_weights = false;
- fc_info.constant_weights = false;
+ fc_info.activation_info = activation_info;
+ if(!constant_weights)
+ {
+ fc_info.are_weights_reshaped = true;
+ fc_info.transpose_weights = false;
+ }
FunctionType fc;
fc.configure(&_src, &_weights, &_bias, &_dst, fc_info);
@@ -369,29 +403,55 @@ public:
// Run multiple iterations with different inputs
constexpr int num_iterations = 5;
int randomizer_offset = 0;
+
+ // Create reference tensors
+ SimpleTensor<T> src{ src_shape, data_type, 1, src_qinfo };
+ SimpleTensor<T> weights{ weights_shape, data_type, 1, weights_qinfo };
+ SimpleTensor<TBias> bias{ bias_shape, bias_data_type };
+
+ // Fill weights and/or bias if they remain constant
+ if(constant_weights)
+ {
+ fill(AccessorType(_weights), 1);
+ fill(weights, 1);
+ }
+ if(constant_bias)
+ {
+ fill(AccessorType(_bias), 2);
+ fill(bias, 2);
+ }
+
for(int i = 0; i < num_iterations; ++i)
{
// Run target
{
fill(AccessorType(_src), randomizer_offset);
- fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1);
- fill(AccessorType(_bias), randomizer_offset + 2);
+ if(!constant_weights)
+ {
+ fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1);
+ }
+ if(!constant_bias)
+ {
+ fill(AccessorType(_bias), randomizer_offset + 2);
+ }
fc.run();
}
// Run reference and compare
{
- SimpleTensor<T> src{ src_shape, data_type };
- SimpleTensor<T> weights{ weights_shape, data_type };
- SimpleTensor<T> bias{ bias_shape, data_type };
-
// Fill reference
fill(src, randomizer_offset);
- fill(weights, randomizer_offset + 1);
- fill(bias, randomizer_offset + 2);
+ if(!constant_weights)
+ {
+ fill(weights, randomizer_offset + 1);
+ }
+ if(!constant_bias)
+ {
+ fill(bias, randomizer_offset + 2);
+ }
- auto dst = reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, dst_shape), activation_info);
+ auto dst = reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, dst_shape, dst_qinfo), activation_info, dst_qinfo);
// Validate
validate_with_tolerance(_dst, dst);
@@ -405,6 +465,32 @@ private:
TensorType _src{}, _weights{}, _bias{}, _dst{};
DataType _data_type{ DataType::UNKNOWN };
};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class FullyConnectedWithDynamicWeightsFixture : public FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
+ DataType data_type, ActivationLayerInfo activation_info)
+ {
+ FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
+ dst_shape, data_type, activation_info, false, true);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class FullyConnectedWithDynamicBiasFixture : public FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
+ DataType data_type, ActivationLayerInfo activation_info)
+ {
+ FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
+ dst_shape, data_type, activation_info, true, false);
+ }
+};
} // namespace validation
} // namespace test
} // namespace arm_compute