diff options
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h index e076f51b26..f71e2a33f9 100644 --- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h +++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h @@ -30,6 +30,7 @@ #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h" #include "arm_compute/core/CL/kernels/CLIm2ColKernel.h" #include "arm_compute/core/CL/kernels/CLTransposeKernel.h" +#include "arm_compute/runtime/CL/CLMemoryGroup.h" #include "arm_compute/runtime/CL/CLTensor.h" namespace arm_compute @@ -64,7 +65,7 @@ class CLFullyConnectedLayer : public IFunction { public: /** Constructor */ - CLFullyConnectedLayer(); + CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr); /** Set the input and output tensors. * * @param[in] input Source tensor. Data type supported: QS8/QS16/F16/F32. @@ -83,6 +84,7 @@ private: void configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output); void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output); + CLMemoryGroup _memory_group; CLIm2ColKernel _im2col_kernel; CLFullyConnectedLayerReshapeWeights _reshape_weights_kernel; CLGEMMMatrixMultiplyKernel _mm_kernel; |