diff options
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h')
-rw-r--r-- | arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h index 08099b8539..463a7d53e3 100644 --- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h +++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h @@ -32,6 +32,7 @@ #include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h" #include "arm_compute/core/NEON/kernels/NEIm2ColKernel.h" #include "arm_compute/core/NEON/kernels/NETransposeKernel.h" +#include "arm_compute/runtime/MemoryGroup.h" #include "arm_compute/runtime/Tensor.h" namespace arm_compute @@ -47,7 +48,7 @@ class NEFullyConnectedLayerReshapeWeights : public IFunction { public: /** Constructor */ - NEFullyConnectedLayerReshapeWeights(); + NEFullyConnectedLayerReshapeWeights(std::shared_ptr<IMemoryManager> memory_manager = nullptr); /** Set the input and output tensors. * * @param[in] input Weights tensor. The weights must be 2 dimensional. Data types supported: QS8/QS16/F32. @@ -61,6 +62,7 @@ public: void run() override; private: + MemoryGroup _memory_group; NETransposeKernel _transpose_kernel; NEGEMMTranspose1xWKernel _transpose1xW_kernel; Tensor _transpose_output; @@ -81,7 +83,7 @@ class NEFullyConnectedLayer : public IFunction { public: /** Constructor */ - NEFullyConnectedLayer(); + NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr); /** Set the input and output tensors. * * @param[in] input Source tensor. Data type supported: QS8/QS16/F32. @@ -97,6 +99,7 @@ public: void run() override; private: + MemoryGroup _memory_group; NEIm2ColKernel _im2col_kernel; NEFullyConnectedLayerReshapeWeights _reshape_weights_kernel; NEGEMMInterleave4x4Kernel _interleave4x4_kernel; |