From f3622becf1f0d6bf5147ebb7d6d0f14d5252860a Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Mon, 29 Jul 2019 14:27:16 +0100 Subject: COMPMID-1979: Fuse Activation Function in CLGEMM - part 4 Fused activation function in CLGEMM Change-Id: I644fdf09349325c0b3a2cd5fef2a3ea2c974149d Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/1640 Comments-Addressed: Arm Jenkins Reviewed-by: Georgios Pinitas Tested-by: Arm Jenkins --- arm_compute/core/Types.h | 19 +- arm_compute/runtime/CL/functions/CLGEMM.h | 2 - .../runtime/CL/functions/CLGEMMConvolutionLayer.h | 29 ++-- examples/cl_cache.cpp | 8 +- src/runtime/CL/functions/CLGEMM.cpp | 91 +++------- .../CL/functions/CLGEMMConvolutionLayer.cpp | 193 ++++++++++++--------- tests/datasets/LargeGEMMDataset.h | 28 +-- tests/datasets/SmallGEMMDataset.h | 24 +-- tests/validation/CL/GEMMMatrixMultiply.cpp | 2 +- .../CL/GEMMMatrixMultiplyInterleavedTransposed.cpp | 2 +- tests/validation/fixtures/GEMMFixture.h | 38 ++-- 11 files changed, 221 insertions(+), 215 deletions(-) diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index b4d94eced4..2c17f273a5 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -1775,7 +1775,8 @@ public: _gemmlowp_output_stage(), _fp_mixed_precision(false), _broadcast_bias(false), - _pretranpose_B(true) + _pretranpose_B(true), + _activation_info() { } /** Constructor @@ -1791,9 +1792,11 @@ public: * @param[in] gemmlowp_output_stage (Optional) GEMMLowp Output stage info * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy. * @param[in] broadcast_bias (Optional) Broadcast the shape of the bias tensor from a vector to a matrix. + * @param[in] activation_info (Optional) Activation to apply after the matrix multiplication */ GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false, - GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false) noexcept + GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false, + const ActivationLayerInfo &activation_info = ActivationLayerInfo()) noexcept : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), @@ -1803,7 +1806,8 @@ public: _gemmlowp_output_stage(gemmlowp_output_stage), _fp_mixed_precision(fp_mixed_precision), _broadcast_bias(broadcast_bias), - _pretranpose_B(reshape_b_only_on_first_run) + _pretranpose_B(reshape_b_only_on_first_run), + _activation_info(activation_info) { } /** Flag which specifies if the matrix A has been reshaped @@ -1896,6 +1900,14 @@ public: { _pretranpose_B = flag; } + /** Activation layer to apply after the matrix multiplication + * + * @return ActivationLayerInfo object + */ + ActivationLayerInfo activation_info() const + { + return _activation_info; + } private: bool _is_a_reshaped; @@ -1908,6 +1920,7 @@ private: bool _fp_mixed_precision; bool _broadcast_bias; bool _pretranpose_B; + ActivationLayerInfo _activation_info; }; /** Winograd information */ diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h index 8c462fa4cb..e2a92a8a37 100644 --- a/arm_compute/runtime/CL/functions/CLGEMM.h +++ b/arm_compute/runtime/CL/functions/CLGEMM.h @@ -127,7 
+127,6 @@ private: CLMemoryGroup _memory_group; CLGEMMMatrixMultiplyKernel _mm_kernel; - CLGEMMMatrixAdditionKernel _ma_kernel; CLGEMMReshapeLHSMatrixKernel _reshape_lhs_kernel; CLGEMMReshapeRHSMatrixKernel _reshape_rhs_kernel; CLGEMMMatrixMultiplyReshapedKernel _mm_reshaped_kernel; @@ -135,7 +134,6 @@ private: CLTensor _tmp_a; CLTensor _tmp_b; const ICLTensor *_original_b; - bool _run_addition; bool _reshape_b_only_on_first_run; bool _is_prepared; GEMMType _gemm_type; diff --git a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h index e9a3f9bf2b..027727c7f7 100644 --- a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h +++ b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h @@ -163,8 +163,10 @@ private: * except for input of QASYMM8 type where output should be of S32 type. * @param[in] gemmlowp_output_stage GEMMLowp output stage info * @param[in] gemm_3d_depth Depth of GEMM 3D + * @param[in] act_info Activation to apply after the matrix multiplication */ - void configure_mm(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth = 1); + void configure_mm(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth, + const ActivationLayerInfo &act_info); /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMConvolutionLayer matrix multiply routines * * @param[in] input Input tensor. Data types supported: QASYMM8/F16/F32. @@ -176,22 +178,21 @@ private: * @param[in] gemmlowp_output_stage GEMMLowp output stage info * @param[in] gemm_3d_depth Depth of GEMM 3D * @param[in] skip_im2col Flag which specifies if im2col has to be skipped. i.e. 1x1 convolution with NHWC data layout. - * @param[in] run_addition Flag which specifies if @ref CLGEMMMatrixMatrixMultiplyAddition to be run. 
+ * @param[in] act_info Activation to apply after the matrix multiplication * * @return a status */ static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const GEMMLowpOutputStageInfo &gemmlowp_output_stage, - int gemm_3d_depth, bool skip_im2col, bool run_addition); + int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info); private: - CLMemoryGroup _memory_group; - CLConvolutionLayerReshapeWeights _reshape_weights; - CLIm2ColKernel _im2col_kernel; - CLGEMM _mm_gemm; - CLGEMMLowpMatrixMultiplyCore _mm_gemmlowp; - CLCol2ImKernel _col2im_kernel; - CLActivationLayer _activationlayer_function; - CLSaturatedArithmeticOperationKernel _add_bias_kernel; + CLMemoryGroup _memory_group; + CLConvolutionLayerReshapeWeights _reshape_weights; + CLIm2ColKernel _im2col_kernel; + CLGEMM _mm_gemm; + CLGEMMLowpMatrixMultiplyCore _mm_gemmlowp; + CLCol2ImKernel _col2im_kernel; + CLActivationLayer _activationlayer_function; const ICLTensor *_original_weights; @@ -199,15 +200,11 @@ private: CLTensor _weights_reshaped; CLTensor _gemm_output; - DataLayout _data_layout; - - bool _append_bias; bool _skip_im2col; bool _skip_col2im; bool _is_quantized; - bool _is_activationlayer_enabled; + bool _fuse_activation; bool _is_prepared; - bool _run_addition; }; } // namespace arm_compute #endif /* __ARM_COMPUTE_CLGEMMCONVOLUTIONLAYER_H__ */ diff --git a/examples/cl_cache.cpp b/examples/cl_cache.cpp index 998c4682ba..7d8a515424 100644 --- a/examples/cl_cache.cpp +++ b/examples/cl_cache.cpp @@ -28,8 +28,6 @@ #include "arm_compute/runtime/CL/CLScheduler.h" #include "utils/Utils.h" -#include <chrono> - using namespace arm_compute; using namespace utils; @@ -46,7 +44,7 @@ public: { std::cout << "Once the program has run and created the file cache.bin, rerun with --restore_cache."
<< std::endl; CLScheduler::get().default_init(); - auto start_time = std::chrono::high_resolution_clock::now(); + if(argc > 1) { std::string argv1 = argv[1]; @@ -88,10 +86,6 @@ public: permute_nchw.configure(&tensor_nhwc, &tensor_nchw_result, vector_nhwc_to_nchw); tensor_nchw_result.allocator()->allocate(); - auto end_time = std::chrono::high_resolution_clock::now(); - auto time_elapsed = end_time - start_time; - auto time_elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(time_elapsed).count(); - std::cout << "Configuration time " << time_elapsed_ms << " ms " << std::endl; // Save the opencl kernels to a file save_program_cache_to_file(); diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index c0ccd0f451..e78395f1de 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -48,7 +48,6 @@ using namespace arm_compute::cl_gemm; CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager) : _memory_group(std::move(memory_manager)), _mm_kernel(), - _ma_kernel(), _reshape_lhs_kernel(), _reshape_rhs_kernel(), _mm_reshaped_kernel(), @@ -56,7 +55,6 @@ CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager) _tmp_a(), _tmp_b(), _original_b(nullptr), - _run_addition(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _gemm_type(GEMMType::NATIVE) @@ -118,10 +116,10 @@ void CLGEMM::configure_native(const ICLTensor *a, const ICLTensor *b, const ICLT // Set the target for the kernels _mm_kernel.set_target(gpu_target); - GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d()); + GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d(), gemm_info.broadcast_bias()); // Configure and tune matrix multiply kernel - _mm_kernel.configure(a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision()); + _mm_kernel.configure(a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info()); // Tune kernel statically CLScheduler::get().tune_kernel_static(_mm_kernel); @@ -162,7 +160,7 @@ void CLGEMM::configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const lhs_info.interleave = true; lhs_info.transpose = true; - GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false); + GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias()); _memory_group.manage(&_tmp_a); if(!_reshape_b_only_on_first_run) @@ -177,7 +175,7 @@ _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info); // Configure and tune matrix multiply kernel - _mm_kernel.configure(&_tmp_a, &_tmp_b, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision()); + _mm_kernel.configure(&_tmp_a, &_tmp_b, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info()); CLScheduler::get().tune_kernel_static(_mm_kernel); @@ -200,13 +198,15 @@ void CLGEMM::configure_reshaped_v2(const ICLTensor *a, const ICLTensor *b, const const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); const GPUTarget gpu_target = CLScheduler::get().target(); bool broadcast_bias = gemm_info.broadcast_bias(); - GEMMKernelInfo kernel_info; + + GEMMKernelInfo kernel_info; kernel_info.m = m; kernel_info.n = n; kernel_info.k = k; kernel_info.depth_output_gemm3d =
depth_output_gemm3d; kernel_info.reinterpret_input_as_3d = false; kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = gemm_info.activation_info(); // Set the target for the kernels _reshape_lhs_kernel.set_target(gpu_target); @@ -255,13 +255,15 @@ void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); const GPUTarget gpu_target = CLScheduler::get().target(); bool broadcast_bias = gemm_info.broadcast_bias(); - GEMMKernelInfo kernel_info; + + GEMMKernelInfo kernel_info; kernel_info.m = m; kernel_info.n = n; kernel_info.k = k; kernel_info.depth_output_gemm3d = depth_output_gemm3d; kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = gemm_info.activation_info(); // Set the target for the kernels _mm_kernel.set_target(gpu_target); @@ -305,21 +307,12 @@ Status CLGEMM::validate_native(const ITensorInfo *a, const ITensorInfo *b, const const unsigned int n = b->dimension(0); const unsigned int k = a->dimension(0); const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); - const bool add_c = (beta != 0.f && c != nullptr); - const bool is_beta_one = std::abs(1.0f - beta) < 0.00001f; - const bool fuse_add = is_beta_one && (c != nullptr && c->num_dimensions() == 1); - const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d); + const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, gemm_info.broadcast_bias()); // Validate matrix multiply - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, (add_c && fuse_add) ? c : nullptr, output, alpha, beta, - false, reshape_info, gpu_target, gemm_info.fp_mixed_precision())); - - if(add_c && !fuse_add) - { - // Validate matrix addition kernel - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta)); - } + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, c, output, alpha, beta, + false, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info())); return Status{}; } @@ -340,9 +333,6 @@ Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, int mult_transpose1xW_width = 1; int mult_interleave4x4_height = 1; const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); - const bool add_c = (beta != 0.f && c != nullptr); - const bool is_beta_one = std::abs(1.0f - beta) < 0.00001f; - const bool fuse_add = is_beta_one && (c != nullptr && c->num_dimensions() == 1); if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST) { @@ -364,7 +354,7 @@ Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, lhs_info.interleave = true; lhs_info.transpose = true; - const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false); + const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias()); // Validate interleave kernel auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d()))); @@ -375,14 +365,8 @@ Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, 
&tmp_b_info, rhs_info)); // Validate matrix multiply - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, (add_c && fuse_add) ? c : nullptr, output, alpha, beta, - true, reshape_info, gpu_target, gemm_info.fp_mixed_precision())); - - if(add_c && !fuse_add) - { - // Validate matrix addition kernel - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta)); - } + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, + true, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info())); return Status{}; } @@ -405,13 +389,15 @@ Status CLGEMM::validate_reshaped_v2(const ITensorInfo *a, const ITensorInfo *b, const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); const bool broadcast_bias = gemm_info.broadcast_bias(); - GEMMKernelInfo kernel_info; + + GEMMKernelInfo kernel_info; kernel_info.m = m; kernel_info.n = n; kernel_info.k = k; kernel_info.depth_output_gemm3d = depth_output_gemm3d; kernel_info.reinterpret_input_as_3d = false; kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = gemm_info.activation_info(); GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; @@ -452,13 +438,15 @@ Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); const bool broadcast_bias = gemm_info.broadcast_bias(); - GEMMKernelInfo kernel_info; + + GEMMKernelInfo kernel_info; kernel_info.m = m; kernel_info.n = n; kernel_info.k = k; kernel_info.depth_output_gemm3d = depth_output_gemm3d; kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = gemm_info.activation_info(); GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; @@ -501,9 +489,7 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * // Select GEMMType _gemm_type = select_gemm_type(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target); - const bool is_fuse_add_c_supported = (_gemm_type == GEMMType::RESHAPED_V2) || (_gemm_type == GEMMType::RESHAPED_ONLY_RHS); - const bool add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); - const bool fuse_add_c = add_c && is_fuse_add_c_supported; + const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); const ICLTensor *c_to_use = fuse_add_c ? 
c : nullptr; @@ -534,13 +520,6 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * ARM_COMPUTE_ERROR("GEMMType not supported"); } } - - // Configure matrix addition kernel - if(add_c && !fuse_add_c) - { - _ma_kernel.configure(c, output, beta); - _run_addition = true; - } } Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info) @@ -555,9 +534,7 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso // Select GEMMType GEMMType gemm_type = select_gemm_type(m, n, k, a->data_type(), gemm_info.reshape_b_only_on_first_run(), gpu_target); - const bool is_fuse_add_c_supported = (gemm_type == GEMMType::RESHAPED_V2) || (gemm_type == GEMMType::RESHAPED_ONLY_RHS); - const bool add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); - const bool fuse_add_c = add_c && is_fuse_add_c_supported; + const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr; @@ -589,12 +566,6 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso } } - // Validate matrix addition kernel - if(add_c && !fuse_add_c) - { - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta)); - } - return Status{}; } @@ -609,7 +580,7 @@ void CLGEMM::run() { case GEMMType::NATIVE: { - CLScheduler::get().enqueue(_mm_kernel, !_run_addition); + CLScheduler::get().enqueue(_mm_kernel, true); break; } case GEMMType::RESHAPED_V1: { @@ -623,7 +594,7 @@ void CLGEMM::run() CLScheduler::get().enqueue(_reshape_rhs_kernel, false); } - CLScheduler::get().enqueue(_mm_kernel, !_run_addition); + CLScheduler::get().enqueue(_mm_kernel, true); break; } case GEMMType::RESHAPED_V2: { @@ -637,7 +608,7 @@ void CLGEMM::run() CLScheduler::get().enqueue(_reshape_rhs_kernel, false); } - CLScheduler::get().enqueue(_mm_reshaped_kernel, !_run_addition); + CLScheduler::get().enqueue(_mm_reshaped_kernel, true); break; } case GEMMType::RESHAPED_ONLY_RHS: { @@ -648,7 +619,7 @@ void CLGEMM::run() CLScheduler::get().enqueue(_reshape_rhs_kernel, false); } - CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, !_run_addition); + CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, true); break; } default: { ARM_COMPUTE_ERROR("GEMMType not supported"); } } - - // Run matrix addition kernel - if(_run_addition) - { - CLScheduler::get().enqueue(_ma_kernel); - } } void CLGEMM::prepare() diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp index 99f045a0bf..be6be04703 100644 --- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp @@ -91,22 +91,27 @@ void CLConvolutionLayerReshapeWeights::run() } CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager) - : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(), _add_bias_kernel(), - _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), - _is_activationlayer_enabled(false), _is_prepared(false), _run_addition(true) + : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(),
_mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(), + _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _is_prepared(false) { } void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const GEMMLowpOutputStageInfo &gemmlowp_output_stage, - int gemm_3d_depth) + int gemm_3d_depth, const ActivationLayerInfo &act_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights); - ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), gemmlowp_output_stage, gemm_3d_depth, _skip_im2col, - _run_addition)); - - const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, - gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, - false, gemmlowp_output_stage); + ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), gemmlowp_output_stage, gemm_3d_depth, _skip_im2col, act_info)); + + const GEMMInfo &gemm_info = GEMMInfo(false, // is_a_reshaped + false, // is_b_reshaped + true, // reshape_b_only_on_first_run + gemm_3d_depth, // depth_output_gemm3d + _skip_im2col, // reinterpret_input_as_3d + false, // retain_internal_weights + gemmlowp_output_stage, // gemmlowp_output_stage + false, // fp_mixed_precision + true, // broadcast_bias + act_info); // activation_info if(_is_quantized) { @@ -126,21 +131,26 @@ void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTenso } else { - // Bias does not need to be added in GEMM if im2col is being used or the Matrix Addition kernel needs to be run - const bool skip_bias_in_gemm = _run_addition || !_skip_im2col; // Configure matrix multiply function - _mm_gemm.configure(input, weights, (skip_bias_in_gemm) ? 
nullptr : biases, output, 1.0f, 1.0f, gemm_info); + _mm_gemm.configure(input, weights, biases, output, 1.0f, 1.0f, gemm_info); } } Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, - const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth, bool skip_im2col, bool run_addition) + const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info) { const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type()); - const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, - gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, - false, gemmlowp_output_stage); + const GEMMInfo &gemm_info = GEMMInfo(false, // is_a_reshaped + false, // is_b_reshaped + true, // reshape_b_only_on_first_run + gemm_3d_depth, // depth_output_gemm3d + skip_im2col, // reinterpret_input_as_3d + false, // retain_internal_weights + gemmlowp_output_stage, // gemmlowp_output_stage + false, // fp_mixed_precision + true, // broadcast_bias + act_info); // activation_info if(is_quantized) { @@ -159,10 +169,8 @@ Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITens } else { - // Bias does not need to be added in GEMM if im2col is being used or the Matrix Addition kernel needs to be run - const bool skip_bias_in_gemm = run_addition || !skip_im2col; // Perform validation step on Matrix multiply function - return CLGEMM::validate(input, weights, (skip_bias_in_gemm) ? nullptr : biases, output, 1.0f, 1.0f, gemm_info); + return CLGEMM::validate(input, weights, biases, output, 1.0f, 1.0f, gemm_info); } } @@ -194,15 +202,14 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * const UniformQuantizationInfo wq_info = weights->info()->quantization_info().uniform(); const UniformQuantizationInfo oq_info = output->info()->quantization_info().uniform(); - _is_prepared = weights_info.retain_internal_weights(); - _original_weights = weights; - _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type()); - _data_layout = data_layout; - _skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1); - _skip_col2im = data_layout == DataLayout::NHWC; - _append_bias = (biases != nullptr) && (!_is_quantized); - _is_activationlayer_enabled = act_info.enabled(); - _run_addition = (_skip_im2col) && (_append_bias); + _is_prepared = weights_info.retain_internal_weights(); + _original_weights = weights; + _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type()); + _skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1); + _skip_col2im = data_layout == DataLayout::NHWC; + + // Only for quantized types are there a few cases where we cannot fuse the activation function in GEMM + _fuse_activation = true; // Set the GPU target for im2col and col2im _im2col_kernel.set_target(CLScheduler::get().target()); @@ -211,8 +218,6 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * const ICLTensor *gemm_input_to_use = input; ICLTensor *gemm_output_to_use = output; - const ICLTensor *biases_to_use = (_append_bias && !_skip_im2col) ?
biases : nullptr; - // Get parameters from conv_info unsigned int stride_x = 0; unsigned int stride_y = 0; @@ -230,9 +235,22 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * unsigned int mat_weights_cols = weights->info()->dimension(idx_kernels) / num_groups; - // _weights_reshaped will be auto configured in the kernel. - // Just append biases and do not transpose 1xW as it will be reshaped in CLGEMM - _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped, num_groups); + const ICLTensor *biases_to_use = biases; + bool append_bias = false; + + if(num_groups != 1 && biases != nullptr) + { + // num_groups != 1 can only be for NCHW + // Since a utility function to reshape the biases is missing, we append the biases to the weights tensor + biases_to_use = nullptr; + append_bias = true; + + _reshape_weights.configure(weights, biases, &_weights_reshaped, num_groups); + } + else + { + _reshape_weights.configure(weights, nullptr, &_weights_reshaped, num_groups); + } // Create tensor to store im2col reshaped inputs if(!_skip_im2col) { _memory_group.manage(&_im2col_output); // Configure and tune im2col. im2col output shape is auto-initialized - _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, dilation, num_groups); + _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation, num_groups); // Set quantization info _im2col_output.info()->set_quantization_info(input->info()->quantization_info()); @@ -249,11 +267,6 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * // Update GEMM input gemm_input_to_use = &_im2col_output; } - else if(_append_bias) - { - // Configure add bias kernel - _add_bias_kernel.configure(ArithmeticOperation::ADD, output, biases, output, ConvertPolicy::SATURATE); - } // Create GEMM output tensor if(!_skip_col2im) @@ -299,16 +312,20 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU }; - if(_is_activationlayer_enabled && supported_acts.count(act_info.activation()) != 0) + if(act_info.enabled()) { - const int a_const_int = quantize_qasymm8(act_info.a(), output_quant_info); - const int b_const_int = quantize_qasymm8(act_info.b(), output_quant_info); - - min_activation = act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU ? output_quant_info.offset : b_const_int; - max_activation = act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU ? 255 : a_const_int; - - // If the activation layer is RELU, BOUNDED_RELU or LU_BOUNDED_RELU, we can use the GEMMLowp output stage to perform this operation - _is_activationlayer_enabled = false; + if(supported_acts.count(act_info.activation()) != 0) + { + const int a_const_int = quantize_qasymm8(act_info.a(), output_quant_info); + const int b_const_int = quantize_qasymm8(act_info.b(), output_quant_info); + + min_activation = act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU ? output_quant_info.offset : b_const_int; + max_activation = act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU ?
255 : a_const_int; + } + else + { + _fuse_activation = false; + } } // Set the GEMMLowp output stage info @@ -323,7 +340,7 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * // In case of NHWC, we need to run GEMM3D (gemm_3d_depth != 0) in order to avoid reshaping the output matrix const unsigned int gemm_3d_depth = (data_layout == DataLayout::NHWC) ? conv_h : 0; - configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth); + configure_mm(gemm_input_to_use, &_weights_reshaped, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, act_info); if(!_skip_im2col) { @@ -345,7 +362,7 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(idx_width) != conv_w) || (output->info()->dimension(idx_height) != conv_h), "Output shape does not match the expected one"); - if(_is_activationlayer_enabled) + if(!_fuse_activation) { _activationlayer_function.configure(output, nullptr, act_info); } @@ -382,12 +399,10 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI const ITensorInfo *gemm_output_to_use = output; const ITensorInfo *weights_to_use = weights; - const bool is_quantized = is_data_type_quantized_asymmetric(data_type); - const bool append_bias = (biases != nullptr) && (!is_quantized); - const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1); - const bool skip_col2im = data_layout == DataLayout::NHWC; - bool is_activationlayer_enabled = act_info.enabled(); - const bool run_addition = (skip_im2col) && (append_bias); + const bool is_quantized = is_data_type_quantized_asymmetric(data_type); + const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1); + const bool skip_col2im = data_layout == DataLayout::NHWC; + bool fuse_activation = true; const UniformQuantizationInfo iq_info = input->quantization_info().uniform(); const UniformQuantizationInfo wq_info = weights->quantization_info().uniform(); @@ -429,10 +444,26 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI unsigned int mat_weights_cols = weights->dimension(idx_kernels) / num_groups; - // Output tensor auto inizialitation if not yet initialized - ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayerReshapeWeights::validate(weights, is_quantized ? 
nullptr : biases, nullptr, num_groups)); - weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, (append_bias && !skip_im2col), num_groups), 1, data_type); - weights_to_use = &weights_reshaped_info; + const ITensorInfo *biases_to_use = biases; + bool append_bias = false; + + if(num_groups != 1 && biases != nullptr) + { + // num_groups != 1 can only be for NCHW + // Since a utility function to reshape the biases is missing, we append the biases to the weights tensor + biases_to_use = nullptr; + append_bias = true; + + ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayerReshapeWeights::validate(weights, biases, nullptr, num_groups)); + weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, true, num_groups), 1, data_type); + } + else + { + ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayerReshapeWeights::validate(weights, nullptr, nullptr, num_groups)); + weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, false, num_groups), 1, data_type); + } + + weights_to_use = &weights_reshaped_info; if(!skip_im2col) { @@ -446,11 +477,6 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, kernel_dims, conv_info, append_bias, dilation, num_groups)); gemm_input_to_use = &im2col_reshaped_info; } - else if(run_addition) - { - // Validate add bias kernel - ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, output, biases, output, ConvertPolicy::SATURATE)); - } // Create GEMM output tensor if(!skip_col2im) { @@ -490,16 +516,20 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU }; - if(is_activationlayer_enabled && supported_acts.count(act_info.activation()) != 0) + if(act_info.enabled()) { - const int a_const_int = quantize_qasymm8(act_info.a(), output_quant_info); - const int b_const_int = quantize_qasymm8(act_info.b(), output_quant_info); - - min_activation = act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU ? output_quant_info.offset : b_const_int; - max_activation = act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU ? 255 : a_const_int; - - // If the activation layer is RELU, BOUNDED_RELU or LU_BOUNDED_RELU, we can use the GEMMLowp output stage to perform this operation - is_activationlayer_enabled = false; + if(supported_acts.count(act_info.activation()) != 0) + { + const int a_const_int = quantize_qasymm8(act_info.a(), output_quant_info); + const int b_const_int = quantize_qasymm8(act_info.b(), output_quant_info); + + min_activation = act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU ? output_quant_info.offset : b_const_int; + max_activation = act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU ? 255 : a_const_int; + } + else + { + fuse_activation = false; + } } // Set the GEMMLowp output stage info @@ -513,7 +543,7 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI // In case of NHWC, we need to run GEMM3D (gemm_3d_depth != 0) in order to avoid reshaping the output matrix const unsigned int gemm_3d_depth = (data_layout == DataLayout::NHWC) ?
conv_h : 0; - ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, skip_im2col, run_addition)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, skip_im2col, act_info)); // Validate Col2Im if(!skip_col2im) { @@ -522,7 +552,7 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI } //Validate Activation Layer - if(is_activationlayer_enabled) + if(!fuse_activation) { ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(output, nullptr, act_info)); } @@ -554,19 +584,14 @@ void CLGEMMConvolutionLayer::run() _mm_gemm.run(); } - if(_run_addition) - { - CLScheduler::get().enqueue(_add_bias_kernel); - } - // Reshape output matrix if(!_skip_col2im) { CLScheduler::get().enqueue(_col2im_kernel, false); } - //Run Activation Layer if enabled - if(_is_activationlayer_enabled) + //Run Activation Layer if we cannot fuse it in GEMM + if(!_fuse_activation) { _activationlayer_function.run(); } diff --git a/tests/datasets/LargeGEMMDataset.h b/tests/datasets/LargeGEMMDataset.h index 0876ae1d2c..0ca0b04460 100644 --- a/tests/datasets/LargeGEMMDataset.h +++ b/tests/datasets/LargeGEMMDataset.h @@ -55,13 +55,13 @@ class LargeGEMMOutput3DDataset final : public GEMMDataset public: LargeGEMMOutput3DDataset() { - add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U, 143U, 3U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f); - add_config(TensorShape(681U, 1025U), TensorShape(213U, 681U), TensorShape(213U, 205U, 5U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f); - add_config(TensorShape(364U, 3025U), TensorShape(96U, 364U), TensorShape(96U, 605U, 5U), TensorShape(96U, 605U, 5U), 1.0f, 0.0f); - add_config(TensorShape(1201U, 729U), TensorShape(128U, 1201U), TensorShape(128U, 243U, 3U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f); - add_config(TensorShape(2305U, 169U), TensorShape(384U, 2305U), TensorShape(384U, 13U, 13U), TensorShape(384U, 13U, 13U), 1.0f, 0.0f); - add_config(TensorShape(1729U, 170U), TensorShape(192U, 1729U), TensorShape(192U, 85U, 2U), TensorShape(192U, 85U, 2U), 1.0f, 0.0f); - add_config(TensorShape(1729U, 170U), TensorShape(128U, 1729U), TensorShape(128U, 17U, 10U), TensorShape(128U, 17U, 10U), 1.0f, 0.0f); + add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f); + add_config(TensorShape(681U, 1025U), TensorShape(213U, 681U), TensorShape(213U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f); + add_config(TensorShape(364U, 3025U), TensorShape(96U, 364U), TensorShape(96U), TensorShape(96U, 605U, 5U), 1.0f, 0.0f); + add_config(TensorShape(1201U, 729U), TensorShape(128U, 1201U), TensorShape(128U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f); + add_config(TensorShape(2305U, 169U), TensorShape(384U, 2305U), TensorShape(384U), TensorShape(384U, 13U, 13U), 1.0f, 0.0f); + add_config(TensorShape(1729U, 170U), TensorShape(192U, 1729U), TensorShape(192U), TensorShape(192U, 85U, 2U), 1.0f, 0.0f); + add_config(TensorShape(1729U, 170U), TensorShape(128U, 1729U), TensorShape(128U), TensorShape(128U, 17U, 10U), 1.0f, 0.0f); } }; @@ -70,13 +70,13 @@ class LargeGEMMInputOutput3DDataset final : public GEMMDataset public: LargeGEMMInputOutput3DDataset() { - add_config(TensorShape(923U, 143U, 3U), TensorShape(871U, 923U), TensorShape(871U, 143U, 3U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f); - add_config(TensorShape(681U, 205U, 5U),
TensorShape(213U, 681U), TensorShape(213U, 205U, 5U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f); - add_config(TensorShape(364U, 605U, 5U), TensorShape(96U, 364U), TensorShape(96U, 605U, 5U), TensorShape(96U, 605U, 5U), 0.2f, 1.2f); - add_config(TensorShape(1201U, 243U, 3U), TensorShape(128U, 1201U), TensorShape(128U, 243U, 3U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f); - add_config(TensorShape(2305U, 13U, 13U), TensorShape(384U, 2305U), TensorShape(384U, 13U, 13U), TensorShape(384U, 13U, 13U), 0.4f, 0.7f); - add_config(TensorShape(1729U, 85U, 2U, 2U), TensorShape(192U, 1729U), TensorShape(192U, 85U, 2U, 2U), TensorShape(192U, 85U, 2U, 2U), 1.0f, 0.0f); - add_config(TensorShape(1729U, 17U, 10U, 3U), TensorShape(128U, 1729U), TensorShape(128U, 17U, 10U, 3U), TensorShape(128U, 17U, 10U, 3U), 1.0f, 0.3f); + add_config(TensorShape(923U, 143U, 3U), TensorShape(871U, 923U), TensorShape(871U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f); + add_config(TensorShape(681U, 205U, 5U), TensorShape(213U, 681U), TensorShape(213U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f); + add_config(TensorShape(364U, 605U, 5U), TensorShape(96U, 364U), TensorShape(96U), TensorShape(96U, 605U, 5U), 0.2f, 1.2f); + add_config(TensorShape(1201U, 243U, 3U), TensorShape(128U, 1201U), TensorShape(128U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f); + add_config(TensorShape(2305U, 13U, 13U), TensorShape(384U, 2305U), TensorShape(384U), TensorShape(384U, 13U, 13U), 0.4f, 0.7f); + add_config(TensorShape(1729U, 85U, 2U, 2U), TensorShape(192U, 1729U), TensorShape(192U), TensorShape(192U, 85U, 2U, 2U), 1.0f, 0.0f); + add_config(TensorShape(1729U, 17U, 10U, 3U), TensorShape(128U, 1729U), TensorShape(128U), TensorShape(128U, 17U, 10U, 3U), 1.0f, 0.3f); } }; } // namespace datasets diff --git a/tests/datasets/SmallGEMMDataset.h b/tests/datasets/SmallGEMMDataset.h index ae3c3ed86d..45d1a1e07e 100644 --- a/tests/datasets/SmallGEMMDataset.h +++ b/tests/datasets/SmallGEMMDataset.h @@ -55,12 +55,12 @@ class SmallGEMMOutput3DDataset final : public GEMMDataset public: SmallGEMMOutput3DDataset() { - add_config(TensorShape(21U, 14U), TensorShape(34U, 21U), TensorShape(34U, 7U, 2U), TensorShape(34U, 7U, 2U), 1.0f, 0.0f); - add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U, 1U), TensorShape(23U, 1U, 1U), 1.0f, 0.0f); - add_config(TensorShape(38U, 12U), TensorShape(21U, 38U), TensorShape(21U, 4U, 3U), TensorShape(21U, 4U, 3U), 0.2f, 1.2f); - add_config(TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U, 1U, 1U), TensorShape(17U, 1U, 1U), 0.4f, 0.7f); - add_config(TensorShape(16U, 16U), TensorShape(8U, 16U), TensorShape(8U, 8U, 2U), TensorShape(8U, 8U, 2U), 1.0f, 0.0f); - add_config(TensorShape(16U, 16U, 5U), TensorShape(8U, 16U, 5U), TensorShape(8U, 8U, 2U, 5U), TensorShape(8U, 8U, 2U, 5U), 1.0f, 0.0f); + add_config(TensorShape(21U, 14U), TensorShape(34U, 21U), TensorShape(34U), TensorShape(34U, 7U, 2U), 1.0f, 0.0f); + add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U), TensorShape(23U, 1U, 1U), 1.0f, 0.0f); + add_config(TensorShape(38U, 12U), TensorShape(21U, 38U), TensorShape(21U), TensorShape(21U, 4U, 3U), 0.2f, 1.2f); + add_config(TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U), TensorShape(17U, 1U, 1U), 0.4f, 0.7f); + add_config(TensorShape(16U, 16U), TensorShape(8U, 16U), TensorShape(8U), TensorShape(8U, 8U, 2U), 1.0f, 0.0f); + add_config(TensorShape(16U, 16U, 5U), TensorShape(8U, 16U, 5U), TensorShape(8U), TensorShape(8U, 8U, 2U, 5U), 1.0f, 0.0f); } }; @@ -69,12 +69,12 @@ class 
SmallGEMMInputOutput3DDataset final : public GEMMDataset public: SmallGEMMInputOutput3DDataset() { - add_config(TensorShape(21U, 14U, 13U), TensorShape(34U, 21U), TensorShape(34U, 14U, 13U), TensorShape(34U, 14U, 13U), 1.0f, 0.0f); - add_config(TensorShape(31U, 1U, 3U), TensorShape(23U, 31U), TensorShape(23U, 1U, 3U), TensorShape(23U, 1U, 3U), 1.0f, 0.0f); - add_config(TensorShape(38U, 12U, 2U), TensorShape(21U, 38U), TensorShape(21U, 12U, 2U), TensorShape(21U, 12U, 2U), 0.2f, 1.2f); - add_config(TensorShape(32U, 1U, 4U, 3U), TensorShape(17U, 32U), TensorShape(17U, 1U, 4U, 3U), TensorShape(17U, 1U, 4U, 3U), 0.4f, 0.7f); - add_config(TensorShape(16U, 16U, 3U, 2U), TensorShape(8U, 16U), TensorShape(8U, 16U, 3U, 2U), TensorShape(8U, 16U, 3U, 2U), 1.0f, 0.0f); - add_config(TensorShape(16U, 16U, 5U, 3U), TensorShape(8U, 16U), TensorShape(8U, 16U, 5U, 3U), TensorShape(8U, 16U, 5U, 3U), 1.0f, 0.3f); + add_config(TensorShape(21U, 14U, 13U), TensorShape(34U, 21U), TensorShape(34U), TensorShape(34U, 14U, 13U), 1.0f, 0.0f); + add_config(TensorShape(31U, 1U, 3U), TensorShape(23U, 31U), TensorShape(23U), TensorShape(23U, 1U, 3U), 1.0f, 0.0f); + add_config(TensorShape(38U, 12U, 2U), TensorShape(21U, 38U), TensorShape(21U), TensorShape(21U, 12U, 2U), 0.2f, 1.2f); + add_config(TensorShape(32U, 1U, 4U, 3U), TensorShape(17U, 32U), TensorShape(17U), TensorShape(17U, 1U, 4U, 3U), 0.4f, 0.7f); + add_config(TensorShape(16U, 16U, 3U, 2U), TensorShape(8U, 16U), TensorShape(8U), TensorShape(8U, 16U, 3U, 2U), 1.0f, 0.0f); + add_config(TensorShape(16U, 16U, 5U, 3U), TensorShape(8U, 16U), TensorShape(8U), TensorShape(8U, 16U, 5U, 3U), 1.0f, 0.3f); } }; } // namespace datasets diff --git a/tests/validation/CL/GEMMMatrixMultiply.cpp b/tests/validation/CL/GEMMMatrixMultiply.cpp index 21fd7125ec..8f7c0aaef1 100644 --- a/tests/validation/CL/GEMMMatrixMultiply.cpp +++ b/tests/validation/CL/GEMMMatrixMultiply.cpp @@ -67,7 +67,7 @@ RelativeTolerance<half> rel_tolerance_f16(half(0.2)); constexpr float tolerance_num_f16 = 0.02f; /** Alpha values to test - Precommit */ -const auto alpha_values = framework::dataset::make("alpha", {0.0f, 1.0f, -0.75f} ); +const auto alpha_values = framework::dataset::make("alpha", {1.0f, -0.75f} ); /** Beta values to test - Precommit */ const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} ); diff --git a/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp index cae94b2e15..5d21cf4f34 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp @@ -77,7 +77,7 @@ RelativeTolerance<half> rel_tolerance_f16(half(0.2)); constexpr float tolerance_num_f16 = 0.02f; /** Alpha values to test - Precommit */ -const auto alpha_values = framework::dataset::make("alpha", {0.0f, 1.0f, -0.75f} ); +const auto alpha_values = framework::dataset::make("alpha", {1.0f, -0.75f} ); /** Beta values to test - Precommit */ const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} ); diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h index b36bb99246..a04a901b1c 100644 --- a/tests/validation/fixtures/GEMMFixture.h +++ b/tests/validation/fixtures/GEMMFixture.h @@ -44,7 +44,7 @@ namespace test { namespace validation { -template +template class GEMMValidationFixture : public framework::Fixture { public: @@ -87,7 +87,13 @@ protected: // The GEMMinfo includes the values of the depth in case of
reinterpreted 3d output. // If the output shape has the same number of dimensions of the input the method called is a 2D matrix multiplication (depth_output_reinterpreted_as_3D = 0), // in the other case we have to use the reinterpreted version of GEMM (depth_output_reinterpreted_as_3D = depth of the 3D output). - gemm.configure(&a, &b, (disable_c) ? nullptr : &c, &dst, alpha, beta, GEMMInfo(false, false, false, (reinterpret_ouput_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d)); + gemm.configure(&a, + &b, + (disable_c) ? nullptr : &c, + &dst, + alpha, beta, + GEMMInfo(false, false, false, (reinterpret_output_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d, false, GEMMLowpOutputStageInfo(), false, (reinterpret_input_as_3d + || reinterpret_output_as_3d))); ARM_COMPUTE_EXPECT(a.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(c.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -122,6 +128,7 @@ protected: DataType data_type) { TensorShape shape_a_to_use = shape_a; + if(reinterpret_input_as_3d) { // Collapse the second and third dimension if the input is 3D @@ -131,22 +138,29 @@ // Create reference SimpleTensor<T> a{ shape_a_to_use, data_type, 1 }; SimpleTensor<T> b{ shape_b, data_type, 1 }; - SimpleTensor<T> c{ shape_c, data_type, 1 }; + SimpleTensor<T> c{ output_shape, data_type, 1 }; // Fill reference fill(a, 0); fill(b, 1); - if(!disable_c) - { - fill(c, 2); - return reference::gemm(a, b, c, alpha, beta); - } - else + fill(c, 2); + + if(reinterpret_input_as_3d || reinterpret_output_as_3d) { - // Setting beta to 0 will effectively disable C for the - // computation of the reference: alpha * A * B + 0 * C - return reference::gemm(a, b, c, alpha, 0.f); + const int n = shape_b[0]; + const int m = reinterpret_output_as_3d ? output_shape[1] * output_shape[2] : output_shape[1]; + const int batch_size = reinterpret_output_as_3d ? output_shape[3] : output_shape[2]; + + // In case of broadcast, we simply need to copy the first row into the following "M" ones + for(int i = 1; i < m * batch_size; i++) + { + memcpy(c.data() + i * n, c.data(), n * sizeof(T)); + } } + + // Setting beta to 0 will effectively disable C for the + // computation of the reference: alpha * A * B + 0 * C + return reference::gemm(a, b, c, alpha, disable_c ? 0.f : beta); } TensorType _target{}; -- cgit v1.2.1
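
For illustration, the following is a minimal sketch of how a caller can request the fused activation after this patch, using the public interfaces changed above (the GEMMInfo constructor in Types.h and CLGEMM). The tensor shapes and the BOUNDED_RELU cap are hypothetical example values, not taken from the patch:

    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/CL/CLScheduler.h"
    #include "arm_compute/runtime/CL/CLTensor.h"
    #include "arm_compute/runtime/CL/functions/CLGEMM.h"

    using namespace arm_compute;

    int main()
    {
        CLScheduler::get().default_init();

        // dst = act(A * B + bias): shapes follow the library's (W, H) ordering,
        // so A is MxK = 32x64, B is KxN = 64x128 and bias is a 1-D vector of N = 128 values.
        CLTensor a, b, bias, dst;
        a.allocator()->init(TensorInfo(TensorShape(64U, 32U), 1, DataType::F32));
        b.allocator()->init(TensorInfo(TensorShape(128U, 64U), 1, DataType::F32));
        bias.allocator()->init(TensorInfo(TensorShape(128U), 1, DataType::F32));
        dst.allocator()->init(TensorInfo(TensorShape(128U, 32U), 1, DataType::F32));

        // The trailing ActivationLayerInfo argument is the parameter introduced by this patch;
        // broadcast_bias = true requests the vector-to-matrix bias broadcast.
        GEMMInfo gemm_info(false /* is_a_reshaped */, false /* is_b_reshaped */, false /* reshape_b_only_on_first_run */,
                           0 /* depth_output_gemm3d */, false /* reinterpret_input_as_3d */, false /* retain_internal_weights */,
                           GEMMLowpOutputStageInfo(), false /* fp_mixed_precision */, true /* broadcast_bias */,
                           ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f));

        CLGEMM gemm;
        gemm.configure(&a, &b, &bias, &dst, 1.f, 1.f, gemm_info);

        a.allocator()->allocate();
        b.allocator()->allocate();
        bias.allocator()->allocate();
        dst.allocator()->allocate();

        gemm.run();
        CLScheduler::get().sync();
        return 0;
    }

This is the same mechanism CLGEMMConvolutionLayer relies on above: it forwards act_info through GEMMInfo and only falls back to the standalone CLActivationLayer when _fuse_activation is false, i.e. for the quantized activations that cannot be expressed through the GEMMLowp output stage.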