aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h76
1 files changed, 67 insertions, 9 deletions
diff --git a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
index 0b27c824d9..017bf78938 100644
--- a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
@@ -39,6 +39,8 @@
#include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h"
#include "arm_compute/runtime/CL/functions/CLReshapeLayer.h"
#include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/ITransformWeights.h"
+#include "arm_compute/runtime/IWeightsManager.h"
#include "arm_compute/runtime/MemoryGroup.h"
#include <memory>
@@ -82,6 +84,59 @@ private:
CLWeightsReshapeKernel _weights_reshape_kernel;
};
+namespace weights_transformations
+{
+/** Weights transform that wraps @ref CLConvolutionLayerReshapeWeights and owns the tensor it writes to.
+ *
+ * configure() sets up the wrapped function to write into an internal tensor, run() allocates that
+ * tensor and executes the reshape, get_weights() exposes it and release() frees its memory.
+ */
+class CLConvolutionLayerReshapeWeightsTransform : public ITransformWeights
+{
+public:
+ /** Configures the @ref CLConvolutionLayerReshapeWeights function
+ *
+ * Also records whether a bias tensor was supplied and the number of groups; both values are
+ * later folded into the value returned by uid(). The output tensor is not allocated here:
+ * the allocation is requested in run().
+ *
+ * @param[in] input Input tensor (presumably the convolution weights to reshape - confirm against callers). Data type supported: QASYMM8/F16/F32.
+ * @param[in] biases Biases tensor. May be nullptr, in which case the bias flag is recorded as 0. Data type supported: Same as @p input.
+ * @param[in] num_groups Number of groups when performing a grouped convolution.
+ */
+ void configure(const ICLTensor *input, const ICLTensor *biases, unsigned int num_groups)
+ {
+ _bias_bit = (biases != nullptr) ? 1 : 0; // 1 when a bias tensor is given, 0 otherwise
+ _num_groups = num_groups;
+ _func.configure(input, biases, &_output, num_groups); // the reshape writes into the internally owned tensor
+ }
+
+ // Inherited method override: allocates the internal output tensor, runs the reshape function and flags the reshape as done
+ void run() override
+ {
+ _output.allocator()->allocate();
+ _func.run();
+ _reshape_run = true; // not declared in this class; presumably a member inherited from ITransformWeights - confirm
+ }
+
+ // Inherited method override: returns a pointer to the internally owned output tensor
+ ICLTensor *get_weights() override
+ {
+ return &_output;
+ }
+
+ // Inherited method override: frees the memory of the internally owned output tensor through its allocator
+ void release() override
+ {
+ _output.allocator()->free();
+ }
+
+ /* Inherited method override: builds the identifier by OR-ing the constant 0x9, the bias flag
+ * shifted to bit 7 and the number of groups shifted left by 8 bits.
+ * NOTE(review): 0x9 is presumably the tag identifying this transform type - confirm against
+ * the other ITransformWeights implementations.
+ */
+ uint32_t uid() override
+ {
+ return ((0x9) | (_bias_bit << 7) | (_num_groups << 8));
+ }
+
+private:
+ CLTensor _output{}; // destination tensor handed to _func in configure()
+ CLConvolutionLayerReshapeWeights _func{}; // wrapped reshape function
+ int32_t _bias_bit{ 0 }; // 1 if configure() received a non-null biases tensor, else 0
+ unsigned int _num_groups{ 0 }; // number of groups passed to configure()
+};
+} // namespace weights_transformations
+
/** Basic function to compute the convolution layer. This function calls the following OpenCL kernels/functions:
*
* -# @ref CLIm2ColKernel
@@ -96,9 +151,10 @@ class CLGEMMConvolutionLayer : public IFunction
public:
/** Constructor
*
- * @param[in] memory_manager (Optional) Memory manager.
+ * @param[in] memory_manager (Optional) Memory manager.
+ * @param[in] weights_manager (Optional) Weights manager.
*/
- CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+ CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
/** Prevent instances of this class from being copied (As this class contains pointers) */
CLGEMMConvolutionLayer(const CLGEMMConvolutionLayer &) = delete;
/** Default move constructor */
@@ -186,13 +242,15 @@ private:
int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info);
private:
- MemoryGroup _memory_group;
- CLConvolutionLayerReshapeWeights _reshape_weights;
- CLIm2ColKernel _im2col_kernel;
- CLGEMM _mm_gemm;
- CLGEMMLowpMatrixMultiplyCore _mm_gemmlowp;
- CLCol2ImKernel _col2im_kernel;
- CLActivationLayer _activationlayer_function;
+ MemoryGroup _memory_group;
+ IWeightsManager *_weights_manager;
+ CLConvolutionLayerReshapeWeights _reshape_weights;
+ weights_transformations::CLConvolutionLayerReshapeWeightsTransform _reshape_weights_managed;
+ CLIm2ColKernel _im2col_kernel;
+ CLGEMM _mm_gemm;
+ CLGEMMLowpMatrixMultiplyCore _mm_gemmlowp;
+ CLCol2ImKernel _col2im_kernel;
+ CLActivationLayer _activationlayer_function;
const ICLTensor *_original_weights;