diff options
Diffstat (limited to 'arm_compute/core')
-rw-r--r-- | arm_compute/core/Helpers.inl | 2 | ||||
-rw-r--r-- | arm_compute/core/Types.h | 18 |
2 files changed, 15 insertions, 5 deletions
diff --git a/arm_compute/core/Helpers.inl b/arm_compute/core/Helpers.inl index d342a60987..78e0c70e1b 100644 --- a/arm_compute/core/Helpers.inl +++ b/arm_compute/core/Helpers.inl @@ -251,8 +251,8 @@ inline bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int if(info.tensor_shape().total_size() == 0) { info.set_data_type(data_type); - info.set_tensor_shape(shape); info.set_num_channels(num_channels); + info.set_tensor_shape(shape); info.set_fixed_point_position(fixed_point_position); return true; } diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 3092976149..1d04f35359 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -607,13 +607,13 @@ private: float _kappa; }; -/** Convolution Layer Weights Information class */ +/** Convolution Layer Weights Information class. This class stores the necessary information to compute convolution layer when the weights are already reshaped */ class WeightsInfo { public: /** Default constructor */ WeightsInfo() - : _are_reshaped(false), _kernel_width(0), _kernel_height(0) + : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0) { } /** Constructor @@ -621,9 +621,10 @@ public: * @param[in] are_reshaped True if the weights have been reshaped * @param[in] kernel_width Kernel width. * @param[in] kernel_height Kernel height. + * @param[in] num_kernels Number of convolution kernels. */ - WeightsInfo(bool are_reshaped, unsigned int kernel_width, unsigned int kernel_height) - : _are_reshaped(are_reshaped), _kernel_width(kernel_width), _kernel_height(kernel_height) + WeightsInfo(bool are_reshaped, unsigned int kernel_width, unsigned int kernel_height, unsigned int num_kernels) + : _are_reshaped(are_reshaped), _kernel_width(kernel_width), _kernel_height(kernel_height), _num_kernels(num_kernels) { } /** Flag which specifies if the weights tensor has been reshaped. @@ -634,6 +635,14 @@ public: { return _are_reshaped; }; + /** Return the number of convolution kernels + * + * @return The number of convolution kernels + */ + unsigned int num_kernels() const + { + return _num_kernels; + }; /** Return the width and height of the kernel * * @return The width and height of the kernel @@ -647,6 +656,7 @@ private: const bool _are_reshaped; const unsigned int _kernel_width; const unsigned int _kernel_height; + const unsigned int _num_kernels; }; /** IO formatting information class*/ |