diff options
Diffstat (limited to 'arm_compute/runtime/NEON/AssemblyHelper.h')
-rw-r--r-- | arm_compute/runtime/NEON/AssemblyHelper.h | 60 |
1 files changed, 54 insertions, 6 deletions
diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h index 2b304b8022..e2d27cf941 100644 --- a/arm_compute/runtime/NEON/AssemblyHelper.h +++ b/arm_compute/runtime/NEON/AssemblyHelper.h @@ -40,26 +40,38 @@ namespace arm_compute { +/** Assembly kernel glue */ template <typename TypeInput, typename TypeOutput> class AssemblyKernelGlue final { public: + /** Operator type */ using TypeOperator = TypeInput; - using TypeResult = TypeOutput; + /** Result type */ + using TypeResult = TypeOutput; + /** Default constructor. */ AssemblyKernelGlue() : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr) { } + /** Assembly Gemm */ using AssemblyGemm = arm_gemm::GemmCommon<TypeInput, TypeOutput>; + /** Prevent instances of this class from being copy assigned */ const AssemblyKernelGlue<TypeInput, TypeOutput> &operator=(const AssemblyKernelGlue<TypeInput, TypeOutput> &) = delete; + /** Prevent instances of this class from being copy constructed */ AssemblyKernelGlue(const AssemblyKernelGlue<TypeInput, TypeOutput> &) = delete; + /** Assembly Gemm kernel */ std::unique_ptr<AssemblyGemm> _gemm_kernel_asm; - std::unique_ptr<INEKernel> _optimised_kernel; - const ITensor *_a; - const ITensor *_b; - ITensor *_d; + /** Optimised NEON kernel */ + std::unique_ptr<INEKernel> _optimised_kernel; + /** Input A */ + const ITensor *_a; + /** Input B */ + const ITensor *_b; + /** Output */ + ITensor *_d; /** Configures the arrays pointers and strides in the assembly kernel and executes the assembly kernel. 
* The call to set_arrays is needed to deal with the input sizes containing batches (dims > 2) @@ -91,10 +103,21 @@ public: } }; -using AssemblyKernelGlueF32 = AssemblyKernelGlue<float, float>; +/** Float 32 assembly kernel glue */ +using AssemblyKernelGlueF32 = AssemblyKernelGlue<float, float>; +/** Uint 8 to Uint 32 kernel glue */ using AssemblyKernelGlueU8U32 = AssemblyKernelGlue<uint8_t, uint32_t>; +/** Int 8 to Int 32 kernel glue */ using AssemblyKernelGlueS8S32 = AssemblyKernelGlue<int8_t, int32_t>; +/** Allocate a workspace tensor. + * + * @param[in] workspace_size Size to allocate. + * @param[out] workspace Tensor to allocate. + * @param[in] memory_group Tensor memory group. + * @param[in] alignment Workspace memory alignment. + * @param[in] num_threads Number of workspace threads. + */ inline void allocate_workspace(size_t workspace_size, Tensor &workspace, MemoryGroup &memory_group, size_t alignment, unsigned int num_threads) { ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0"); @@ -102,6 +125,17 @@ inline void allocate_workspace(size_t workspace_size, Tensor &workspace, MemoryG workspace.allocator()->allocate(); } +/** Create a wrapper kernel. + * + * @param[in] a Input tensor A. + * @param[in] b Input tensor B. + * @param[in] c (Optional) Input tensor C. + * @param[out] d Output tensor. + * @param[in] alpha Alpha value. + * @param[in] beta Beta value. + * + * @return the wrapper kernel. + */ template <typename T> std::unique_ptr<NEGEMMAssemblyWrapper<T>> create_wrapper_kernel(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta) { @@ -128,6 +162,20 @@ std::unique_ptr<NEGEMMAssemblyWrapper<T>> create_wrapper_kernel(const ITensor *a return nullptr; } +/** Setup assembly kernel. + * + * @param[in] a Input tensor A. + * @param[in] b Input tensor B. + * @param[in] c (Optional) Input tensor C. + * @param[out] d Output tensor. + * @param[in] alpha Alpha value. + * @param[in] beta Beta value. 
+ * @param[out] workspace Workspace tensor. + * @param[in] memory_group Tensor memory group. + * @param[out] asm_glue Assembly glue kernel. + * + * @return True if the assembly kernel is set up correctly. + */ template <typename T> inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, Tensor &workspace, MemoryGroup &memory_group, T &asm_glue) |