From ba1ffe96eb4563ba7e18b39728d9db373c62f7c3 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Wed, 22 Aug 2018 14:28:30 +0100 Subject: COMPMID-1537: Fix weights retention in CLFullyConnectedLayer Change-Id: Id978c34889b86fa8b9184d3349cc9b12837141a2 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/145403 Reviewed-by: Anthony Barbier Tested-by: Jenkins --- arm_compute/core/Types.h | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) (limited to 'arm_compute/core/Types.h') diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index d9109e4565..37a8850237 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -1141,7 +1141,7 @@ class GEMMInfo public: /** Default constructor */ GEMMInfo() - : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(1), _reinterpret_input_as_3d(false) + : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(1), _reinterpret_input_as_3d(false), _retain_internal_weights(false) { } /** Constructor @@ -1152,11 +1152,12 @@ public: * @param[in] depth_output_gemm3d (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel * @param[in] reinterpret_input_as_3d (Optional) Reinterpret the input as 3D tensor. (i.e. this flag should be set to true when GEMM is used * to perform 1x1 convolutions with the NHWC data layout) + * @param[in] retain_internal_weights (Optional) Retain the weights tensor from previous run * */ - GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 1, bool reinterpret_input_as_3d = false) + GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 1, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false) : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d), - _reinterpret_input_as_3d(reinterpret_input_as_3d) + _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights) { } /** Flag which specifies if the matrix A has been reshaped @@ -1201,6 +1202,14 @@ public: { return _reinterpret_input_as_3d; }; + /** Flag which specifies if the weights tensor has to be retained from previous run + * + * @return True if the weights tensor has to be retained + */ + bool retain_internal_weights() const + { + return _retain_internal_weights; + }; private: const bool _is_a_reshaped; @@ -1208,6 +1217,7 @@ private: const bool _reshape_b_only_on_first_run; const int _depth_output_gemm3d; const bool _reinterpret_input_as_3d; + const bool _retain_internal_weights; }; /** Winograd information */ -- cgit v1.2.1