From a2bb80ea7111509c24caad8629533089decef430 Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Mon, 19 Jun 2023 14:57:57 +0100 Subject: Use MatMul in fully connected layer with dynamic weights when supported - Use MatMul kernels in FC layer when using dynamic weights without broadcasting or bias. - Fix minor typo in IClMatMulNativeKernelConfig.h Partially Resolves : [COMPMID-6193] Signed-off-by: Mohammed Suhail Munshi Change-Id: Id494062b5b4f4e75ff9714c202dde941955afa52 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9797 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Gunes Bayir Benchmark: Arm Jenkins --- src/core/CL/cl_kernels/common/mat_mul_quantized.cl | 10 +- src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp | 7 +- src/gpu/cl/operators/ClFullyConnected.cpp | 352 ++++++++++++++------- src/gpu/cl/operators/ClFullyConnected.h | 19 +- .../matmul_native/IClMatMulNativeKernelConfig.h | 2 +- tests/validation/CL/FullyConnectedLayer.cpp | 38 ++- .../fixtures/FullyConnectedLayerFixture.h | 63 ++-- 7 files changed, 330 insertions(+), 161 deletions(-) diff --git a/src/core/CL/cl_kernels/common/mat_mul_quantized.cl b/src/core/CL/cl_kernels/common/mat_mul_quantized.cl index bd415bb4a7..8cf857dd84 100644 --- a/src/core/CL/cl_kernels/common/mat_mul_quantized.cl +++ b/src/core/CL/cl_kernels/common/mat_mul_quantized.cl @@ -21,9 +21,9 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ +#include "activation_float_helpers.h" #include "helpers.h" #include "tile_helpers.h" -#include "activation_float_helpers.h" #if defined(MAT_MUL_NATIVE_QUANTIZED_NT_NT) /** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS non-transposed, RHS non-transposed - buffer only @@ -189,7 +189,7 @@ __kernel void mat_mul_native_quantized_nt_nt( { LOOP_UNROLLING(int, j, 0, 1, N0, { - acc[i].s[j] += ((int)RHS_OFFSET) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j]; + acc[i].s[j] -= ((int)RHS_OFFSET) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j]; }) }) @@ -368,7 +368,7 @@ __kernel void mat_mul_native_quantized_nt_t( { LOOP_UNROLLING(int, j, 0, 1, N0, { - acc[i].s[j] += ((int)(RHS_OFFSET)) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j]; + acc[i].s[j] -= ((int)(RHS_OFFSET)) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j]; }) }) @@ -549,7 +549,7 @@ __kernel void mat_mul_native_quantized_t_nt( { LOOP_UNROLLING(int, j, 0, 1, N0, { - acc[i].s[j] += ((int)(RHS_OFFSET)) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j]; + acc[i].s[j] -= ((int)(RHS_OFFSET)) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j]; }) }) @@ -734,7 +734,7 @@ __kernel void mat_mul_native_quantized_t_t( { LOOP_UNROLLING(int, j, 0, 1, N0, { - acc[i].s[j] += ((int)RHS_OFFSET) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j]; + acc[i].s[j] -= ((int)RHS_OFFSET) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j]; }) }) diff --git a/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp b/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp index 9bbec908a3..38d78c618b 100644 --- a/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp +++ b/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp @@ -164,9 +164,10 @@ void ClMatMulLowpNativeKernel::configure(const ClCompileContext &compile_context build_opts.add_option("-DDST_MULTIPLIER=" + support::cpp11::to_string(output_multiplier)); build_opts.add_option("-DDST_SHIFT=" + support::cpp11::to_string(output_shift)); - build_opts.add_option("-DLHS_OFFSET=" + 
support::cpp11::to_string(-lqinfo.offset)); // Note this is passed as negative to maintain similarity with CLDirectConv2D - build_opts.add_option("-DRHS_OFFSET=" + support::cpp11::to_string(-rqinfo.offset)); // Note this is passed as negative to maintain similarity with CLDirectConv2D - build_opts.add_option("-DDST_OFFSET=" + support::cpp11::to_string(dqinfo.offset)); // Passed as positive (unlike the above two) + // Note : Offset is not negated, unlike gemmlowp kernels + build_opts.add_option("-DLHS_OFFSET=" + support::cpp11::to_string(lqinfo.offset)); + build_opts.add_option("-DRHS_OFFSET=" + support::cpp11::to_string(rqinfo.offset)); + build_opts.add_option("-DDST_OFFSET=" + support::cpp11::to_string(dqinfo.offset)); // Passed as positive (unlike the above two) build_opts.add_option(("-DA_VAL=" + float_to_string_with_full_precision(act_info.a()))); build_opts.add_option(("-DB_VAL=" + float_to_string_with_full_precision(act_info.b()))); diff --git a/src/gpu/cl/operators/ClFullyConnected.cpp b/src/gpu/cl/operators/ClFullyConnected.cpp index b289cc0104..c62e4b531f 100644 --- a/src/gpu/cl/operators/ClFullyConnected.cpp +++ b/src/gpu/cl/operators/ClFullyConnected.cpp @@ -38,6 +38,12 @@ #include "src/gpu/cl/operators/ClTranspose.h" #include "src/gpu/cl/utils/ClAuxTensorHandler.h" +#include "src/gpu/cl/operators/ClMatMul.h" +#include "utils/TypePrinter.h" + +#include "src/runtime/heuristics/matmul_native/ClMatMulNativeKernelConfig.h" +#include "src/runtime/heuristics/matmul_native/IClMatMulNativeKernelConfig.h" + #include "src/common/utils/Log.h" #include "support/Cast.h" @@ -52,6 +58,12 @@ using namespace arm_compute::misc::shape_calculator; namespace { +// Function to calculate batched tensor shape in format [M, 1, B0, B1 ..] which is the format matmul expects +inline TensorShape get_reshaped_matmul_tensor(const TensorShape &src) +{ + return TensorShape(src.x(), 1, src.y(), src.collapsed_from(2).z()); // Return value optimisation +} + Status construct_gemmlowp_output_stage(const ITensorInfo &src, const ITensorInfo &weights, const ITensorInfo &dst, GEMMLowpOutputStageInfo &gemmlowp_output_stage, ActivationLayerInfo activation_info) { @@ -101,41 +113,61 @@ Status construct_gemmlowp_output_stage(const ITensorInfo &src, const ITensorInfo Status validate_mm(const ITensorInfo &src, const ITensorInfo &weights, const ITensorInfo *bias, const ITensorInfo &dst, const FullyConnectedLayerInfo &fc_info) { - GEMMLowpOutputStageInfo gemmlowp_output_stage; - ARM_COMPUTE_RETURN_ON_ERROR(construct_gemmlowp_output_stage(src, weights, dst, gemmlowp_output_stage, fc_info.activation_info)); - - const GEMMInfo &gemm_info = GEMMInfo(false, // is_a_reshaped - false, // is_b_reshaped - true, // reshape_b_only_on_first_run - 0, // depth_output_gemm3d - false, // reinterpret_input_as_3d - fc_info.retain_internal_weights, // retain_internal_weights - gemmlowp_output_stage, // gemmlowp_output_stage - fc_info.fp_mixed_precision, // fp_mixed_precision - false, // fast_math - true, // broadcast_bias - ActivationLayerInfo()); // activation_info - - if(is_data_type_quantized_asymmetric(src.data_type())) + // If weights are dynamic, data is not batched, and bias is nullptr validate using matmul. + const bool weights_reshaped = fc_info.transpose_weights ? 
fc_info.are_weights_reshaped : true;
+    const bool use_matmul        = !weights.are_values_constant() && !weights_reshaped && !(dst.dimension(1) > 1) && (bias == nullptr);
+
+    if(use_matmul)
     {
-        const UniformQuantizationInfo iq_info = src.quantization_info().uniform();
-        const UniformQuantizationInfo wq_info = weights.quantization_info().uniform();
-
-        // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
-        // Extract and negate src and weights offset
-        const QuantizationInfo src_quantization_info(iq_info.scale, -iq_info.offset);
-        const QuantizationInfo weights_quantization_info(wq_info.scale, -wq_info.offset);
-
-        // Validate gemmlowp function
-        ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixMultiplyCore::validate(&src.clone()->set_quantization_info(src_quantization_info),
-                                                                           &weights.clone()->set_quantization_info(weights_quantization_info),
-                                                                           bias,
-                                                                           &dst,
-                                                                           gemm_info));
+        MatMulInfo m_info{};
+        m_info.adj_rhs(fc_info.transpose_weights);
+
+        // Note: Currently, shape is [M, B0, B1]
+        // LHS is reshaped here to match ClMatMul expectations of batch index in format - [M, 1, B0, B1, .. ]
+        TensorInfo lhs_to_use{ src };
+        lhs_to_use.set_tensor_shape(get_reshaped_matmul_tensor(src.tensor_shape()));
+
+        // Operator level validation.
+        ARM_COMPUTE_RETURN_ON_ERROR(ClMatMul::validate(&lhs_to_use, &weights, &dst, m_info, fc_info.activation_info));
     }
     else
     {
-        ARM_COMPUTE_RETURN_ON_ERROR(ClGemm::validate(&src, &weights, bias, &dst, 1.f, 1.f, gemm_info));
+        GEMMLowpOutputStageInfo gemmlowp_output_stage;
+        ARM_COMPUTE_RETURN_ON_ERROR(construct_gemmlowp_output_stage(src, weights, dst, gemmlowp_output_stage, fc_info.activation_info));
+
+        const GEMMInfo &gemm_info = GEMMInfo(false,                           // is_a_reshaped
+                                             false,                           // is_b_reshaped
+                                             true,                            // reshape_b_only_on_first_run
+                                             0,                               // depth_output_gemm3d
+                                             false,                           // reinterpret_input_as_3d
+                                             fc_info.retain_internal_weights, // retain_internal_weights
+                                             gemmlowp_output_stage,           // gemmlowp_output_stage
+                                             fc_info.fp_mixed_precision,      // fp_mixed_precision
+                                             false,                           // fast_math
+                                             true,                            // broadcast_bias
+                                             ActivationLayerInfo());          // activation_info
+
+        if(is_data_type_quantized_asymmetric(src.data_type()))
+        {
+            const UniformQuantizationInfo iq_info = src.quantization_info().uniform();
+            const UniformQuantizationInfo wq_info = weights.quantization_info().uniform();
+
+            // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
+            // Extract and negate src and weights offset
+            const QuantizationInfo src_quantization_info(iq_info.scale, -iq_info.offset);
+            const QuantizationInfo weights_quantization_info(wq_info.scale, -wq_info.offset);
+
+            // Validate gemmlowp function
+            ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixMultiplyCore::validate(&src.clone()->set_quantization_info(src_quantization_info),
+                                                                               &weights.clone()->set_quantization_info(weights_quantization_info),
+                                                                               bias,
+                                                                               &dst,
+                                                                               gemm_info));
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR(ClGemm::validate(&src, &weights, bias, &dst, 1.f, 1.f, gemm_info));
+        }
     }

     return Status{};
@@ -148,6 +180,8 @@ ClFullyConnected::ClFullyConnected()
       _reshape_weights(nullptr),
       _mm_gemm(nullptr),
       _mm_gemmlowp(nullptr),
+      _matmul_native_kernel(nullptr),
+      _matmul_lowp_native_kernel(nullptr),
       _aux_mem(Count)
 {
 }
@@ -157,50 +191,85 @@ ClFullyConnected::~ClFullyConnected() = default;
 void ClFullyConnected::configure_mm(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *bias, ITensorInfo *dst, const FullyConnectedLayerInfo &fc_info)
 {
-    
GEMMLowpOutputStageInfo gemmlowp_output_stage; - construct_gemmlowp_output_stage(*src, *weights, *dst, gemmlowp_output_stage, fc_info.activation_info); - - const GEMMInfo &gemm_info = GEMMInfo(false, // is_a_reshaped - false, // is_b_reshaped - !_dynamic_weights, // reshape_b_only_on_first_run - 0, // depth_output_gemm3d - false, // reinterpret_input_as_3d - fc_info.retain_internal_weights, // retain_internal_weights - gemmlowp_output_stage, // gemmlowp_output_stage - fc_info.fp_mixed_precision, // fp_mixed_precision - false, // fast_math - true, // broadcast_bias - fc_info.activation_info); // activation_info - - if(_is_quantized) + // If weights are dynamic, configure matmul operator - else use gemm + if(_use_matmul) { - // Since we need negative offsets for computing convolution, we need to change QuantizationInfo() - // Extract and negate input and weights offset - const QuantizationInfo src_quantization_info = src->quantization_info(); - const QuantizationInfo weights_quantization_info = weights->quantization_info(); - - TensorInfo src_info = src->clone()->set_quantization_info(src_quantization_info); - TensorInfo weights_info = weights->clone()->set_quantization_info(weights_quantization_info); - - src_info.set_quantization_info(QuantizationInfo(src_quantization_info.uniform().scale, -src_quantization_info.uniform().offset)); - weights_info.set_quantization_info(QuantizationInfo(weights_quantization_info.uniform().scale, -weights_quantization_info.uniform().offset)); - - // Configure gemmlowp function - _mm_gemmlowp = std::make_unique(); - _mm_gemmlowp->configure(compile_context, &src_info, &weights_info, bias, dst, gemm_info); + // Transpose RHS as _are_weights_reshaped == false when mat_mul is used. + const MatMulInfo mat_info = MatMulInfo().adj_rhs(fc_info.transpose_weights); + + // Note: MatMul does not need offset negation unlike gemm + // 1. Change shape when calling matmul to fit batch expectations. + _lhs_to_use = *src->clone(); + _lhs_to_use.set_tensor_shape(get_reshaped_matmul_tensor(_lhs_to_use.tensor_shape())); // Collapse all dims > 2 into final dimension. + _is_quantized = is_data_type_quantized_asymmetric(_lhs_to_use.data_type()); + + // 2. Call kernel for matmul directly. 
+ const GPUTarget gpu_target = CLScheduler::get().target(); + std::unique_ptr kernel_config = cl_matmul::ClMatMulNativeKernelConfigurationFactory::create(gpu_target); + + // Configure relevant matmul kernel + MatMulKernelInfo kernel_info = kernel_config->configure(src, weights, mat_info); + if(_is_quantized) + { + _matmul_lowp_native_kernel = std::make_unique(); + _matmul_lowp_native_kernel->set_target(gpu_target); + _matmul_lowp_native_kernel->configure(compile_context, src, weights, dst, kernel_info, fc_info.activation_info); + } + else + { + _matmul_native_kernel = std::make_unique(); + _matmul_native_kernel->set_target(gpu_target); + _matmul_native_kernel->configure(compile_context, src, weights, dst, kernel_info, fc_info.activation_info); + } } else { - // Configure matrix multiply kernel - _mm_gemm = std::make_unique(); - _mm_gemm->configure(compile_context, src, weights, bias, dst, 1.f, 1.f, gemm_info); + // Configure GEMM + GEMMLowpOutputStageInfo gemmlowp_output_stage; + construct_gemmlowp_output_stage(*src, *weights, *dst, gemmlowp_output_stage, fc_info.activation_info); + + const GEMMInfo &gemm_info = GEMMInfo(false, // is_a_reshaped + false, // is_b_reshaped + !_dynamic_weights, // reshape_b_only_on_first_run + 0, // depth_output_gemm3d + false, // reinterpret_input_as_3d + fc_info.retain_internal_weights, // retain_internal_weights + gemmlowp_output_stage, // gemmlowp_output_stage + fc_info.fp_mixed_precision, // fp_mixed_precision + false, // fast_math + true, // broadcast_bias + fc_info.activation_info); // activation_info + + if(_is_quantized) + { + // Since we need negative offsets for computing convolution, we need to change QuantizationInfo() + // Extract and negate input and weights offset + const QuantizationInfo src_quantization_info = src->quantization_info(); + const QuantizationInfo weights_quantization_info = weights->quantization_info(); + + TensorInfo src_info = src->clone()->set_quantization_info(src_quantization_info); + TensorInfo weights_info = weights->clone()->set_quantization_info(weights_quantization_info); + + src_info.set_quantization_info(QuantizationInfo(src_quantization_info.uniform().scale, -src_quantization_info.uniform().offset)); + weights_info.set_quantization_info(QuantizationInfo(weights_quantization_info.uniform().scale, -weights_quantization_info.uniform().offset)); + + // Configure gemmlowp function + _mm_gemmlowp = std::make_unique(); + _mm_gemmlowp->configure(compile_context, &src_info, &weights_info, bias, dst, gemm_info); + } + else + { + // Configure matrix multiply kernel + _mm_gemm = std::make_unique(); + _mm_gemm->configure(compile_context, src, weights, bias, dst, 1.f, 1.f, gemm_info); + } } } void ClFullyConnected::configure_conv_fc(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *bias, ITensorInfo *dst, const FullyConnectedLayerInfo &fc_info) { - ARM_COMPUTE_ERROR_ON((weights->dimension(1) != (src->dimension(0) * src->dimension(1) * src->dimension(2)))); + ARM_COMPUTE_ERROR_ON((weights->dimension((_use_matmul) ? 
0 : 1) != (src->dimension(0) * src->dimension(1) * src->dimension(2)))); // If the fully connected layer is called after a convolution layer, the input tensor must be linearized @@ -211,6 +280,7 @@ void ClFullyConnected::configure_conv_fc(const CLCompileContext &compile_context _flatten = std::make_unique(); _flatten->configure(compile_context, src, &_flattened_src); + // Note: if flatten has > 1 dimensions after, these dimensions are batch // Configure matrix multiply kernel configure_mm(compile_context, &_flattened_src, weights, bias, dst, fc_info); } @@ -218,7 +288,8 @@ void ClFullyConnected::configure_conv_fc(const CLCompileContext &compile_context void ClFullyConnected::configure_fc_fc(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *bias, ITensorInfo *dst, const FullyConnectedLayerInfo &fc_info) { - ARM_COMPUTE_ERROR_ON(src->dimension(0) != weights->dimension(1)); + // Compare first dimension when using matmul, as it performs transpose operation + ARM_COMPUTE_ERROR_ON(src->dimension(0) != weights->dimension((_use_matmul) ? 0 : 1)); // Configure matrix multiply kernel configure_mm(compile_context, src, weights, bias, dst, fc_info); @@ -240,7 +311,13 @@ void ClFullyConnected::configure(const CLCompileContext &compile_context, ITenso _is_prepared = fc_info.retain_internal_weights; _weights_to_use = TensorInfo(*weights); _weights_to_use_idx = ACL_SRC_1; - _dynamic_weights = !weights->are_values_constant() && !_are_weights_reshaped; + + // When using dynamic weights - use matmul kernels. + // Note: We don't appear to support dynamic weights with pre-reshaped RHS. + // Note: No matmul with biases for the moment. + const bool is_batched_fc_layer = dst->dimension(1) > 1; + _dynamic_weights = !weights->are_values_constant() && !_are_weights_reshaped; + _use_matmul = _dynamic_weights && !is_batched_fc_layer && !(biases); // With the Fully Connected layer we can have 4 different cases: // 1) Convolution layer -> Fully Connected layer without batches @@ -249,7 +326,6 @@ void ClFullyConnected::configure(const CLCompileContext &compile_context, ITenso // 4) Fully Connected layer -> Fully Connected layer with batches // Check if we have a fully connected layer with batches - const bool is_batched_fc_layer = dst->dimension(1) > 1; if(is_batched_fc_layer) { _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, @@ -264,7 +340,8 @@ void ClFullyConnected::configure(const CLCompileContext &compile_context, ITenso ITensorInfo *weights_used = weights; // Reshape weights if needed - if(!_are_weights_reshaped) + // Not needed when matmul is in use - MatMul has transpose RHS flags. + if(!_are_weights_reshaped && !_use_matmul) { // Reshape the weights _reshape_weights = std::make_unique(); @@ -302,39 +379,47 @@ void ClFullyConnected::configure(const CLCompileContext &compile_context, ITenso // Update TensorInfo of final weights used (Need to be done in the end due to padding expansion) _weights_to_use = *weights_used; - // Set auxiliary memory requirements - auto gemm_mem_req = (_is_quantized) ? 
_mm_gemmlowp->workspace() : _mm_gemm->workspace();
-    for(unsigned int i = 0; i < gemm_mem_req.size(); ++i)
+    if(_use_matmul)
     {
-        _aux_mem[i] = gemm_mem_req[i];
-    }
-    if(_aux_mem[1].size > 0 || _aux_mem[2].size > 0) // Persistent weights memory on GEMMs
-    {
-        // Release permuted weights at the of prepare as they are further transposed by the assembly dispatch
-        // Keep all the auxiliary tensors in case of dynamic weights as they are recalculated every time
-        _aux_mem[TransposedWeights] = MemoryInfo(
-            offset_int_vec(TransposedWeights),
-            _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
-            _reshaped_weights.total_size());
-        _aux_mem[ConvertedWeights] = MemoryInfo(
-            offset_int_vec(ConvertedWeights),
-            _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
-            _converted_weights.total_size());
+        // Note : MatMul does not use transpose and does not need auxiliary memory, so only converted weights are added to aux_mem
+        _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Temporary, _converted_weights.total_size());
     }
     else
     {
-        // Release permuted weights at the of prepare as they are further transposed by the assembly dispatch
-        const auto transposed_wei_lft = (_weights_to_use_idx == offset_int_vec(TransposedWeights)) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare;
-        const auto converted_wei_lft  = (_weights_to_use_idx == offset_int_vec(ConvertedWeights)) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare;
-
-        _aux_mem[TransposedWeights] = MemoryInfo(
-            offset_int_vec(TransposedWeights),
-            _dynamic_weights ? MemoryLifetime::Temporary : transposed_wei_lft,
-            _reshaped_weights.total_size());
-        _aux_mem[ConvertedWeights] = MemoryInfo(
-            offset_int_vec(ConvertedWeights),
-            _dynamic_weights ? MemoryLifetime::Temporary : converted_wei_lft,
-            _converted_weights.total_size());
+        // Set auxiliary memory requirements for gemm operators
+        auto gemm_mem_req = (_is_quantized) ? _mm_gemmlowp->workspace() : _mm_gemm->workspace();
+        for(unsigned int i = 0; i < gemm_mem_req.size(); ++i)
+        {
+            _aux_mem[i] = gemm_mem_req[i];
+        }
+        if(_aux_mem[1].size > 0 || _aux_mem[2].size > 0) // Persistent weights memory on GEMMs
+        {
+            // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
+            // Keep all the auxiliary tensors in case of dynamic weights as they are recalculated every time
+            _aux_mem[TransposedWeights] = MemoryInfo(
+                offset_int_vec(TransposedWeights),
+                _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
+                _reshaped_weights.total_size());
+            _aux_mem[ConvertedWeights] = MemoryInfo(
+                offset_int_vec(ConvertedWeights),
+                _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
+                _converted_weights.total_size());
+        }
+        else
+        {
+            // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
+            const auto transposed_wei_lft = (_weights_to_use_idx == offset_int_vec(TransposedWeights)) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare;
+            const auto converted_wei_lft  = (_weights_to_use_idx == offset_int_vec(ConvertedWeights)) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare;
+
+            _aux_mem[TransposedWeights] = MemoryInfo(
+                offset_int_vec(TransposedWeights),
+                _dynamic_weights ? MemoryLifetime::Temporary : transposed_wei_lft,
+                _reshaped_weights.total_size());
+            _aux_mem[ConvertedWeights] = MemoryInfo(
+                offset_int_vec(ConvertedWeights),
+                _dynamic_weights ? MemoryLifetime::Temporary : converted_wei_lft,
+                _converted_weights.total_size());
+        }
     }
     _aux_mem[FlattenedSrc] = MemoryInfo(offset_int_vec(FlattenedSrc), MemoryLifetime::Temporary, _flattened_src.total_size());
 }
@@ -349,8 +434,15 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei
     ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type())
                                 && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
                                 && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU
                                 && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
-    bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
-    bool is_fc_after_conv = true;
+    const bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
+    bool       is_fc_after_conv = true;
+
+    // When using dynamic weights - use matmul kernels.
+    // Note: MatMul does not support broadcasting or biases, so fall back to GEMM for batched cases or when biases != nullptr.
+    // Note: Pre-reshaped RHS is a deprecated use case and is therefore not supported with matmul.
+    const bool dynamic_weights     = !weights->are_values_constant() && !weights_reshaped;
+    const bool is_batched_fc_layer = dst->dimension(1) > 1;
+    const bool use_matmul          = dynamic_weights && !is_batched_fc_layer && (biases == nullptr);

     const ITensorInfo &flatten_src       = TensorInfo(src->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_flatten_shape(src)).set_data_layout(DataLayout::NCHW));
     const ITensorInfo &reshaped_weights  = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
@@ -378,8 +470,7 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei
         }
     }

-    // Check if we have a fully connected layer with batches
-    const bool is_batched_fc_layer = dst->dimension(1) > 1;
+    // Check if FC is after conv (flatten kernel is run in case where FC is after conv.)
if(is_batched_fc_layer) { is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, @@ -391,7 +482,7 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei is_fc_after_conv = src->num_dimensions() > 1; } - if(!weights_reshaped) + if(!weights_reshaped && !use_matmul) { // Validate reshape weights kernel ARM_COMPUTE_RETURN_ON_ERROR(ClTranspose::validate(weights, &reshaped_weights)); @@ -411,7 +502,14 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei if(is_fc_after_conv) { // Fully Connected layer after a Convolution Layer without batches - ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(1) != (src->dimension(0) * src->dimension(1) * src->dimension(2)))); + if(use_matmul) + { + ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(0) != (src->dimension(0) * src->dimension(1) * src->dimension(2)))); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(1) != (src->dimension(0) * src->dimension(1) * src->dimension(2)))); + } // Validate flatten kernel ARM_COMPUTE_RETURN_ON_ERROR(ClFlatten::validate(src, &flatten_src)); @@ -420,7 +518,7 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei else { // Fully Connected layer after a Fully Connected Layer without batches - ARM_COMPUTE_RETURN_ERROR_ON(src->dimension(0) != weights_to_use->dimension(1)); + ARM_COMPUTE_RETURN_ERROR_ON(src->dimension(0) != weights_to_use->dimension((use_matmul) ? 0 : 1)); } // Validate matrix multiply kernel @@ -457,14 +555,30 @@ void ClFullyConnected::run(ITensorPack &tensors) gemm_pack.add_const_tensor(ACL_SRC_1, weights.get()); } - // Run matrix multiply - if(_is_quantized) + // Run MatMul Op + if(_use_matmul) { - _mm_gemmlowp->run(gemm_pack); + // Run matmul kernels for matrix multiplication + if(_is_quantized) + { + CLScheduler::get().enqueue_op(*_matmul_lowp_native_kernel, gemm_pack, true); + } + else + { + CLScheduler::get().enqueue_op(*_matmul_native_kernel, gemm_pack, true); + } } else { - _mm_gemm->run(gemm_pack); + // Run matrix multiply + if(_is_quantized) + { + _mm_gemmlowp->run(gemm_pack); + } + else + { + _mm_gemm->run(gemm_pack); + } } } @@ -486,7 +600,7 @@ void ClFullyConnected::prepare(ITensorPack &tensors) const ITensor *cur_weights = weights; // Reshape of the weights if needed - if(!_are_weights_reshaped) + if(!_are_weights_reshaped && !_use_matmul) { // Run reshape weights kernel and mark weights as unused ITensorPack transpose_pack{ { ACL_SRC, weights }, { ACL_DST, reshaped_weights.get() } }; @@ -509,15 +623,19 @@ void ClFullyConnected::prepare(ITensorPack &tensors) ITensorPack gemm_pack = tensors; gemm_pack.add_const_tensor(ACL_SRC_1, cur_weights); - // Prepare GEMM prepare and release unused weights - if(!_is_quantized) + // Prepare GEMM prepare and release unused weights (If not using matmul) + if(!_use_matmul) { - _mm_gemm->prepare(gemm_pack); - } - else - { - _mm_gemmlowp->prepare(gemm_pack); + if(!_is_quantized) + { + _mm_gemm->prepare(gemm_pack); + } + else + { + _mm_gemmlowp->prepare(gemm_pack); + } } + _is_prepared = true; } } diff --git a/src/gpu/cl/operators/ClFullyConnected.h b/src/gpu/cl/operators/ClFullyConnected.h index 11a59b2359..5dc68c1bbe 100644 --- a/src/gpu/cl/operators/ClFullyConnected.h +++ b/src/gpu/cl/operators/ClFullyConnected.h @@ -42,7 +42,12 @@ class ClFlatten; class ClGemm; class ClGemmLowpMatrixMultiplyCore; class ClTranspose; - +// Kernel Forward Declarations +namespace kernels +{ +class 
ClMatMulNativeKernel; +class ClMatMulLowpNativeKernel; +} /** Basic function to compute a Fully Connected layer on OpenCL. This function calls the following OpenCL kernels: * * -# @ref opencl::kernels::ClIm2ColKernel (called when the input comes from a convolutional layer) @@ -119,12 +124,19 @@ private: std::unique_ptr _mm_gemm; std::unique_ptr _mm_gemmlowp; + std::unique_ptr _matmul_native_kernel; + std::unique_ptr _matmul_lowp_native_kernel; + experimental::MemoryRequirements _aux_mem{}; TensorInfo _flattened_src{}; TensorInfo _converted_weights{}; TensorInfo _reshaped_weights{}; + // Saved tensor shapes for reshaping when using matmul + TensorShape _lhs_shape_original{}; + TensorInfo _lhs_to_use{}; + TensorInfo _weights_to_use{}; int _weights_to_use_idx{ ACL_SRC_1 }; @@ -134,10 +146,11 @@ private: bool _is_quantized{ false }; bool _is_prepared{ false }; bool _dynamic_weights{ false }; + bool _use_matmul{ false }; #ifdef ARM_COMPUTE_ASSERTS_ENABLED - int _asrt_run_count{}; - int _asrt_prepare_count{}; + int _asrt_run_count {}; + int _asrt_prepare_count{}; #endif // ARM_COMPUTE_ASSERTS_ENABLED }; } // namespace opencl diff --git a/src/runtime/heuristics/matmul_native/IClMatMulNativeKernelConfig.h b/src/runtime/heuristics/matmul_native/IClMatMulNativeKernelConfig.h index 203f68c253..60e838c5cb 100644 --- a/src/runtime/heuristics/matmul_native/IClMatMulNativeKernelConfig.h +++ b/src/runtime/heuristics/matmul_native/IClMatMulNativeKernelConfig.h @@ -111,6 +111,6 @@ public: protected: GPUTarget _target; }; -} // namespace opencl +} // namespace cl_matmul } // namespace arm_compute #endif /* SRC_RUNTIME_HEURISTICS_MATMUL_NATIVE_ICLMATMULNATIVEKERNELCONFIG */ diff --git a/tests/validation/CL/FullyConnectedLayer.cpp b/tests/validation/CL/FullyConnectedLayer.cpp index 9213ab541d..474a87dd1c 100644 --- a/tests/validation/CL/FullyConnectedLayer.cpp +++ b/tests/validation/CL/FullyConnectedLayer.cpp @@ -131,6 +131,8 @@ template using CLFullyConnectedLayerMixedDataLayoutFixture = FullyConnectedLayerValidationFixture; template using CLFullyConnectedLayerDynamicWeightsFixture = FullyConnectedWithDynamicWeightsFixture; +template +using CLFullyConnectedNoBiasFixture = FullyConnectedDynamicNoBiasFixture; TEST_SUITE(Float) TEST_SUITE(FP16) @@ -151,9 +153,9 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLFullyConnectedLayerFixture, framework:: validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num); } FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(), - framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("ActivationInfo", ActivationLayerInfo())), - framework::dataset::make("WeightsReshaped", { false, true }))) + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo())), + framework::dataset::make("WeightsReshaped", { false, true }))) { } TEST_SUITE_END() @@ -179,9 +181,15 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, CLFullyConnectedLayerMixedDataLayoutF validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0, abs_tolerance_f32); } FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(), - framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("ActivationInfo", ActivationLayerInfo())), - 
framework::dataset::make("WeightsReshaped", { false, true }))) + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo())), + framework::dataset::make("WeightsReshaped", { false, true }))) +{ +} +FIXTURE_DATA_TEST_CASE(RunDynamicNoBias, CLFullyConnectedNoBiasFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) })), + framework::dataset::make("WeightsReshaped", { false }))) { } FIXTURE_DATA_TEST_CASE(RunLarge, CLFullyConnectedLayerFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeFullyConnectedLayerDataset(), FullyConnectedParameters), @@ -230,9 +238,9 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLFullyConnectedLayerQuantizedFixture, validate(CLAccessor(_target), _reference, tolerance_qasymm8); } FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(), - framework::dataset::make("DataType", DataType::QASYMM8)), - framework::dataset::make("ActivationInfo", ActivationLayerInfo())), - framework::dataset::make("WeightsReshaped", { false /* COMPMID-6000: Support FullyConnected with quantized dynamic weights already reshaped */ }))) + framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo())), + framework::dataset::make("WeightsReshaped", { false /* COMPMID-6000: Support FullyConnected with quantized dynamic weights already reshaped */ }))) { } TEST_SUITE_END() /* QASYMM8 */ @@ -259,9 +267,15 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, CLFullyConnectedLayerQuantizedMixedDa validate(CLAccessor(_target), _reference, tolerance_qasymm8); } FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(), - framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), - framework::dataset::make("ActivationInfo", ActivationLayerInfo())), - framework::dataset::make("WeightsReshaped", { false /* COMPMID-6000: Support FullyConnected with quantized dynamic weights already reshaped */ }))) + framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo())), + framework::dataset::make("WeightsReshaped", { false /* COMPMID-6000: Support FullyConnected with quantized dynamic weights already reshaped */ }))) +{ +} +FIXTURE_DATA_TEST_CASE(RunDynamicNoBias, CLFullyConnectedNoBiasFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(), + framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo())), + framework::dataset::make("WeightsReshaped", { false /* COMPMID-6000: Support FullyConnected with quantized dynamic weights already reshaped */ }))) { } TEST_SUITE_END() // QASYMM8_SIGNED diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h index 75bef144ad..e13c01d1e2 100644 --- a/tests/validation/fixtures/FullyConnectedLayerFixture.h +++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h 
@@ -335,9 +335,9 @@ private: void validate_with_tolerance(TensorType &target, SimpleTensor &ref) { - constexpr AbsoluteTolerance abs_tolerance_f16(0.3f); + constexpr AbsoluteTolerance abs_tolerance_f16(0.3f); const RelativeTolerance rel_tolerance_f16(half_float::half(0.2f)); - constexpr float tolerance_num_f16 = 0.07f; + constexpr float tolerance_num_f16 = 0.07f; validate(AccessorType(target), ref, rel_tolerance_f16, tolerance_num_f16, abs_tolerance_f16); } @@ -360,36 +360,36 @@ public: template void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape, - DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias, bool weights_reshaped) + DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias, bool weights_reshaped, bool remove_bias = false) { _data_type = data_type; - const bool is_quantized = is_data_type_quantized(data_type); - + const bool is_quantized = is_data_type_quantized(data_type); const DataType bias_data_type = (is_quantized) ? DataType::S32 : data_type; const QuantizationInfo src_qinfo = is_quantized ? QuantizationInfo(0.1f, 10) : QuantizationInfo(); const QuantizationInfo weights_qinfo = is_quantized ? QuantizationInfo(0.3f, 20) : QuantizationInfo(); const QuantizationInfo dst_qinfo = is_quantized ? QuantizationInfo(0.2f, 5) : QuantizationInfo(); - // Setup tensor meta-data + // Configure TensorInfo Objects const TensorInfo src_info(src_shape, 1, data_type, src_qinfo); - _src.allocator()->init(src_info); + const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo); + TensorInfo bias_info(bias_shape, 1, bias_data_type); + TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo); - TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo); if(!constant_weights && weights_reshaped) { const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] }; wei_info.set_tensor_shape(tr_weights_shape); } wei_info.set_are_values_constant(constant_weights); - _weights.allocator()->init(wei_info); - - TensorInfo bias_info(bias_shape, 1, bias_data_type); bias_info.set_are_values_constant(constant_bias); - _bias.allocator()->init(bias_info); - const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo); + // Initialise Tensors + _src.allocator()->init(src_info); + _weights.allocator()->init(wei_info); + if(!remove_bias) + _bias.allocator()->init(bias_info); _dst.allocator()->init(dst_info); // Configure FC layer and mark the weights as non constant @@ -401,12 +401,13 @@ public: fc_info.transpose_weights = !weights_reshaped; } FunctionType fc; - fc.configure(&_src, &_weights, &_bias, &_dst, fc_info); + fc.configure(&_src, &_weights, (remove_bias) ? 
nullptr : &_bias, &_dst, fc_info); // Allocate all the tensors _src.allocator()->allocate(); _weights.allocator()->allocate(); - _bias.allocator()->allocate(); + if(!remove_bias) + _bias.allocator()->allocate(); _dst.allocator()->allocate(); // Run multiple iterations with different inputs @@ -424,11 +425,20 @@ public: fill(AccessorType(_weights), 1); fill(weights, 1); } - if(constant_bias) + if(constant_bias && !remove_bias) { fill(AccessorType(_bias), 2); fill(bias, 2); } + // To remove bias, fill with 0 + if(remove_bias && is_quantized) + { + library->fill_tensor_value(bias, 0); + } + else if(remove_bias) + { + library->fill_tensor_value(bias, (float)0.0); + } for(int i = 0; i < num_iterations; ++i) { @@ -446,7 +456,7 @@ public: fill(AccessorType(_weights), randomizer_offset + 1); } } - if(!constant_bias) + if(!constant_bias && !remove_bias) { fill(AccessorType(_bias), randomizer_offset + 2); } @@ -462,7 +472,7 @@ public: { fill(weights, randomizer_offset + 1); } - if(!constant_bias) + if(!constant_bias && !remove_bias) { fill(bias, randomizer_offset + 2); } @@ -491,7 +501,20 @@ public: DataType data_type, ActivationLayerInfo activation_info, bool weights_reshaped) { FullyConnectedWithDynamicTensorsFixture::setup(src_shape, weights_shape, bias_shape, - dst_shape, data_type, activation_info, false, true, weights_reshaped); + dst_shape, data_type, activation_info, false, true, weights_reshaped, false); + } +}; + +template +class FullyConnectedDynamicNoBiasFixture : public FullyConnectedWithDynamicTensorsFixture +{ +public: + template + void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape, + DataType data_type, ActivationLayerInfo activation_info, bool weights_reshaped) + { + FullyConnectedWithDynamicTensorsFixture::setup(src_shape, weights_shape, bias_shape, + dst_shape, data_type, activation_info, false, true, weights_reshaped, true); } }; @@ -504,7 +527,7 @@ public: DataType data_type, ActivationLayerInfo activation_info) { FullyConnectedWithDynamicTensorsFixture::setup(src_shape, weights_shape, bias_shape, - dst_shape, data_type, activation_info, true, false, false /* weights_reshaped (not used) */); + dst_shape, data_type, activation_info, true, false, false, false /* weights_reshaped (not used) */); } }; } // namespace validation -- cgit v1.2.1
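Note on the quantized kernel change above: mat_mul_quantized.cl now receives LHS_OFFSET and RHS_OFFSET un-negated (see ClMatMulLowpNativeKernel.cpp) and therefore subtracts the offset-correction term instead of adding it. The standalone sketch below illustrates why the two formulations give the same int32 accumulator; it is illustrative only - the sample values, variable names and use of std::vector are assumptions made for this example and are not part of the patch or the Compute Library API.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    // Arbitrary sample data and zero points (hypothetical values, for illustration only).
    const std::vector<std::uint8_t> lhs{ 12, 250, 3, 77 };
    const std::vector<std::uint8_t> rhs{ 9, 45, 200, 1 };
    const std::int32_t lhs_offset = 10;
    const std::int32_t rhs_offset = 20;
    const std::int32_t k          = static_cast<std::int32_t>(lhs.size());

    // Raw accumulation plus the row/column sums (a_sum / b_sum) that the kernel keeps.
    std::int32_t acc = 0, a_sum = 0, b_sum = 0;
    for(std::size_t i = 0; i < lhs.size(); ++i)
    {
        acc   += static_cast<std::int32_t>(lhs[i]) * static_cast<std::int32_t>(rhs[i]);
        a_sum += static_cast<std::int32_t>(lhs[i]);
        b_sum += static_cast<std::int32_t>(rhs[i]);
    }
    // Constant K * lhs_offset * rhs_offset term, common to both formulations.
    acc += k * lhs_offset * rhs_offset;

    // New formulation: offsets passed as-is, correction term subtracted (acc -= ...).
    const std::int32_t acc_new = acc - (rhs_offset * a_sum + lhs_offset * b_sum);
    // Old formulation: offsets negated on the host, correction term added (acc += ...).
    const std::int32_t acc_old = acc + ((-rhs_offset) * a_sum + (-lhs_offset) * b_sum);

    // Reference: sum_i (lhs[i] - lhs_offset) * (rhs[i] - rhs_offset)
    std::int32_t acc_ref = 0;
    for(std::size_t i = 0; i < lhs.size(); ++i)
    {
        acc_ref += (static_cast<std::int32_t>(lhs[i]) - lhs_offset) * (static_cast<std::int32_t>(rhs[i]) - rhs_offset);
    }

    assert(acc_new == acc_old);
    assert(acc_new == acc_ref);
    return 0;
}

The same identity applies to all four quantized kernel variants touched by the patch (nt_nt, nt_t, t_nt and t_t); only where the offsets are negated changes, not the arithmetic.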