diff options
author | Giorgio Arena <giorgio.arena@arm.com> | 2021-05-17 13:03:50 +0100 |
---|---|---|
committer | Giorgio Arena <giorgio.arena@arm.com> | 2021-05-20 15:19:39 +0000 |
commit | 4403ed3ed09491686a0b182fa498344b005ca812 (patch) | |
tree | 5a231a71d70a7b3ae2412729d8f6a170b54510f7 | |
parent | ea8d266515812c4dec936b2153ffd5335873e583 (diff) | |
download | ComputeLibrary-4403ed3ed09491686a0b182fa498344b005ca812.tar.gz |
Add support for dynamic weights in CL FullyConnected layer
Make GEMM use its native version if weights are dynamic. This ensures that no reshape is performed on the weights tensor.
Enable dynamic weights tests for the OpenCL backend
Resolve COMPMID-4223
Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Change-Id: Iccc4806701772cede23e24df09c786914d00034c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5652
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
-rw-r--r-- | arm_compute/core/Types.h | 18 | ||||
-rw-r--r-- | src/runtime/CL/functions/CLFullyConnectedLayer.cpp | 4 | ||||
-rw-r--r-- | src/runtime/gpu/cl/operators/ClGemm.cpp | 12 | ||||
-rw-r--r-- | tests/validation/CL/FullyConnectedLayer.cpp | 6 | ||||
-rw-r--r-- | tests/validation/fixtures/FullyConnectedLayerFixture.h | 1 |
5 files changed, 31 insertions, 10 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index ec9c419dbc..2dc9a77c39 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -1954,7 +1954,8 @@ public: _fp_mixed_precision(false), _broadcast_bias(false), _pretranpose_B(true), - _activation_info() + _activation_info(), + _constant_weights(true) { } /** Constructor @@ -1971,10 +1972,11 @@ public: * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy. * @param[in] broadcast_bias (Optional) Broadcast the shape of the bias tensor from a vector to a matrix. * @param[in] activation_info (Optional) Activation to apply after the matrix multiplication + * @param[in] constant_weights (Optional) Weights have constant values throughout multiple executions */ GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false, GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false, - const ActivationLayerInfo &activation_info = ActivationLayerInfo()) noexcept + const ActivationLayerInfo &activation_info = ActivationLayerInfo(), bool constant_weights = true) noexcept : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), @@ -1985,7 +1987,8 @@ public: _fp_mixed_precision(fp_mixed_precision), _broadcast_bias(broadcast_bias), _pretranpose_B(reshape_b_only_on_first_run), - _activation_info(activation_info) + _activation_info(activation_info), + _constant_weights(constant_weights) { } /** Flag which specifies if the matrix A has been reshaped @@ -2102,6 +2105,14 @@ public: { _activation_info = activation_info; } + /** Flag which specifies if the values of the weights tensor are constant throughout multiple executions or not + * + * @return True if the weights tensor is 
constant + */ + bool constant_weights() const + { + return _constant_weights; + }; private: bool _is_a_reshaped; @@ -2115,6 +2126,7 @@ private: bool _broadcast_bias; bool _pretranpose_B; ActivationLayerInfo _activation_info; + bool _constant_weights; }; /** Winograd information */ diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp index 50a145f9ca..31c8908270 100644 --- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp +++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp @@ -157,7 +157,8 @@ void CLFullyConnectedLayer::configure_mm(const CLCompileContext &compile_context gemmlowp_output_stage, // gemmlowp_output_stage fc_info.fp_mixed_precision, // fp_mixed_precision true, // broadcast_bias - fc_info.activation_info); // activation_info + fc_info.activation_info, // activation_info + fc_info.constant_weights); // constant_weights if(_is_quantized) { @@ -325,6 +326,7 @@ Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2); ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(input->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU); + ARM_COMPUTE_RETURN_ERROR_ON(!fc_info.constant_weights && (!fc_info.are_weights_reshaped || fc_info.transpose_weights)); bool weights_reshaped = fc_info.transpose_weights ? 
fc_info.are_weights_reshaped : true; bool is_fc_after_conv = true; diff --git a/src/runtime/gpu/cl/operators/ClGemm.cpp b/src/runtime/gpu/cl/operators/ClGemm.cpp index fcbc6d5fba..a80375447d 100644 --- a/src/runtime/gpu/cl/operators/ClGemm.cpp +++ b/src/runtime/gpu/cl/operators/ClGemm.cpp @@ -78,8 +78,13 @@ inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type) } } //Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type -inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run) +inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights) { + if(!constant_weights) + { + return CLGEMMKernelType::NATIVE_V1; + } + auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run); if(bool(gemm_kernel)) { @@ -564,7 +569,8 @@ void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, const unsigned int batch_size = reinterpret_input_as_3d ? 
a->dimension(3) : a->dimension(2); // Select GEMMType - _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run); + _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run, + gemm_info.constant_weights()); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); @@ -613,7 +619,7 @@ Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso { CLScheduler::get().target(), a->data_type(), m, n, k, batch_size, }, - gemm_info.reshape_b_only_on_first_run()); + gemm_info.reshape_b_only_on_first_run(), gemm_info.constant_weights()); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); diff --git a/tests/validation/CL/FullyConnectedLayer.cpp b/tests/validation/CL/FullyConnectedLayer.cpp index 9fa9eb5eaa..09da519c51 100644 --- a/tests/validation/CL/FullyConnectedLayer.cpp +++ b/tests/validation/CL/FullyConnectedLayer.cpp @@ -172,9 +172,9 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, CLFullyConnectedLayerMixedDataLayoutF // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0, abs_tolerance_f32); } -DISABLED_FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallFullyConnectedLayerDataset(), - framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)))) +FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallFullyConnectedLayerDataset(), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("ActivationInfo", 
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)))) { } FIXTURE_DATA_TEST_CASE(RunLarge, CLFullyConnectedLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeFullyConnectedLayerDataset(), FullyConnectedParameters), diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h index e5fea60923..7d767642f3 100644 --- a/tests/validation/fixtures/FullyConnectedLayerFixture.h +++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h @@ -355,6 +355,7 @@ public: FullyConnectedLayerInfo fc_info; fc_info.activation_info = activation_info; fc_info.are_weights_reshaped = true; + fc_info.transpose_weights = false; fc_info.constant_weights = false; FunctionType fc; fc.configure(&_src, &_weights, &_bias, &_dst, fc_info); |