diff options
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NEGEMM.h')
-rw-r--r-- | arm_compute/runtime/NEON/functions/NEGEMM.h | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEGEMM.h b/arm_compute/runtime/NEON/functions/NEGEMM.h index b4b9e8be01..068e7c5ce8 100644 --- a/arm_compute/runtime/NEON/functions/NEGEMM.h +++ b/arm_compute/runtime/NEON/functions/NEGEMM.h @@ -25,6 +25,7 @@ #define __ARM_COMPUTE_NEGEMM_H__ #include "arm_compute/core/NEON/kernels/NEFillBorderKernel.h" +#include "arm_compute/core/NEON/kernels/NEGEMMAssemblyBaseKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMMatrixAdditionKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h" @@ -51,6 +52,7 @@ class NEGEMM : public IFunction public: /** Constructor */ NEGEMM(std::shared_ptr<IMemoryManager> memory_manager = nullptr); + /** Initialise the kernel's inputs, output * * @note GEMM: General Matrix Multiply - [alpha * A * B + beta * C]. @@ -69,15 +71,17 @@ public: void run() override; private: - MemoryGroup _memory_group; - NEGEMMInterleave4x4Kernel _interleave_kernel; - NEGEMMTranspose1xWKernel _transpose_kernel; - NEGEMMMatrixMultiplyKernel _mm_kernel; - NEGEMMMatrixAdditionKernel _ma_kernel; - Tensor _tmp_a; - Tensor _tmp_b; - bool _run_vector_matrix_multiplication; - bool _run_addition; + MemoryGroup _memory_group; + NEGEMMInterleave4x4Kernel _interleave_kernel; + NEGEMMTranspose1xWKernel _transpose_kernel; + NEGEMMMatrixMultiplyKernel _mm_kernel; + std::unique_ptr<NEGEMMAssemblyBaseKernel> _mm_optimised_kernel; + NEGEMMMatrixAdditionKernel _ma_kernel; + Tensor _tmp_a; + Tensor _tmp_b; + Tensor _workspace; + bool _run_vector_matrix_multiplication; + bool _run_addition; }; } #endif /*__ARM_COMPUTE_NEGEMM_H__ */ |