From b62280aca3148dd6762e57e5af3da0cb0a9e2db5 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Thu, 31 May 2018 17:31:05 +0100 Subject: COMPMID-1244: Allow retaining weights in CLGEMMConvolutionLayer and CLFullyConnectedLayer Change-Id: I1c3b2197906cd4b905309bbd5f2012bbae6a7dba Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/133730 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- arm_compute/core/Types.h | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) (limited to 'arm_compute/core/Types.h') diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 46e6dba1a0..639170f0fd 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -946,18 +946,19 @@ class WeightsInfo public: /** Default constructor */ WeightsInfo() - : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0) + : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0), _retain_internal_weights(false) { } /** Constructor * - * @param[in] are_reshaped True if the weights have been reshaped - * @param[in] kernel_width Kernel width. - * @param[in] kernel_height Kernel height. - * @param[in] num_kernels Number of convolution kernels. + * @param[in] are_reshaped True if the weights have been reshaped + * @param[in] kernel_width Kernel width. + * @param[in] kernel_height Kernel height. + * @param[in] num_kernels Number of convolution kernels. + * @param[in] retain_internal_weights (Optional) True if internal reshaped weights must be retained. Used for reconfiguration purposes. Default is false. */ - WeightsInfo(bool are_reshaped, unsigned int kernel_width, unsigned int kernel_height, unsigned int num_kernels) - : _are_reshaped(are_reshaped), _kernel_width(kernel_width), _kernel_height(kernel_height), _num_kernels(num_kernels) + WeightsInfo(bool are_reshaped, unsigned int kernel_width, unsigned int kernel_height, unsigned int num_kernels, bool retain_internal_weights = false) + : _are_reshaped(are_reshaped), _kernel_width(kernel_width), _kernel_height(kernel_height), _num_kernels(num_kernels), _retain_internal_weights(retain_internal_weights) { } /** Flag which specifies if the weights tensor has been reshaped. @@ -984,12 +985,17 @@ public: { return std::make_pair(_kernel_width, _kernel_height); } + bool retain_internal_weights() const + { + return _retain_internal_weights; + } private: const bool _are_reshaped; const unsigned int _kernel_width; const unsigned int _kernel_height; const unsigned int _num_kernels; + const bool _retain_internal_weights; }; /** GEMM reshape information class. This class stores the necessary information about matrix A and matrix B reshape. -- cgit v1.2.1