diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2018-08-22 14:28:30 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:54:54 +0000 |
commit | ba1ffe96eb4563ba7e18b39728d9db373c62f7c3 (patch) | |
tree | 0e931ef35b271f353d98aec59f9e3042471b4aea /arm_compute | |
parent | 3ada2b7a29e1ab2058ab7dc701cacff548d2aae9 (diff) | |
download | ComputeLibrary-ba1ffe96eb4563ba7e18b39728d9db373c62f7c3.tar.gz |
COMPMID-1537: Fix weights retention in CLFullyConnectedLayer
Change-Id: Id978c34889b86fa8b9184d3349cc9b12837141a2
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/145403
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute')
-rw-r--r-- | arm_compute/core/Types.h | 16 | ||||
-rw-r--r-- | arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h | 6 |
2 files changed, 16 insertions, 6 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index d9109e4565..37a8850237 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -1141,7 +1141,7 @@ class GEMMInfo public: /** Default constructor */ GEMMInfo() - : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(1), _reinterpret_input_as_3d(false) + : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(1), _reinterpret_input_as_3d(false), _retain_internal_weights(false) { } /** Constructor @@ -1152,11 +1152,12 @@ public: * @param[in] depth_output_gemm3d (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel * @param[in] reinterpret_input_as_3d (Optional) Reinterpret the input as 3D tensor. (i.e. this flag should be set to true when GEMM is used * to perform 1x1 convolutions with the NHWC data layout) + * @param[in] retain_internal_weights (Optional) Retain the weights tensor from previous run * */ - GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 1, bool reinterpret_input_as_3d = false) + GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 1, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false) : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d), - _reinterpret_input_as_3d(reinterpret_input_as_3d) + _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights) { } /** Flag which specifies if the matrix A has been reshaped @@ -1201,6 +1202,14 @@ public: { return _reinterpret_input_as_3d; }; + /** Flag which specifies if the weights tensor has to be retained from previous run + * + * @return True if the weights tensor has to be retained + */ + bool retain_internal_weights() const + { + return _retain_internal_weights; + }; private: const bool _is_a_reshaped; @@ -1208,6 +1217,7 @@ private: const bool _reshape_b_only_on_first_run; const int _depth_output_gemm3d; const bool _reinterpret_input_as_3d; + const bool _retain_internal_weights; }; /** Winograd information */ diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h index 450cd831ee..d6d88cec55 100644 --- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h +++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h @@ -125,9 +125,9 @@ public: void prepare() override; private: - void configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output); - void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output); - void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output); + void configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights); + void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights); + void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights); CLMemoryGroup _memory_group; CLConvertFullyConnectedWeights _convert_weights; |