Diffstat (limited to 'arm_compute/runtime/NEON/AssemblyHelper.h')
-rw-r--r--  arm_compute/runtime/NEON/AssemblyHelper.h  60
1 file changed, 54 insertions, 6 deletions
diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h
index 2b304b8022..e2d27cf941 100644
--- a/arm_compute/runtime/NEON/AssemblyHelper.h
+++ b/arm_compute/runtime/NEON/AssemblyHelper.h
@@ -40,26 +40,38 @@
namespace arm_compute
{
+/** Assembly kernel glue */
template <typename TypeInput, typename TypeOutput>
class AssemblyKernelGlue final
{
public:
+ /** Operator type */
using TypeOperator = TypeInput;
- using TypeResult = TypeOutput;
+ /** Result type */
+ using TypeResult = TypeOutput;
+ /** Default constructor. */
AssemblyKernelGlue()
: _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr)
{
}
+ /** Assembly Gemm */
using AssemblyGemm = arm_gemm::GemmCommon<TypeInput, TypeOutput>;
+ /** Prevent instances of this class from being copy assigned */
const AssemblyKernelGlue<TypeInput, TypeOutput> &operator=(const AssemblyKernelGlue<TypeInput, TypeOutput> &) = delete;
+ /** Prevent instances of this class from being copy constructed */
AssemblyKernelGlue(const AssemblyKernelGlue<TypeInput, TypeOutput> &) = delete;
+ /** Assembly Gemm kernel */
std::unique_ptr<AssemblyGemm> _gemm_kernel_asm;
- std::unique_ptr<INEKernel> _optimised_kernel;
- const ITensor *_a;
- const ITensor *_b;
- ITensor *_d;
+ /** Optimised NEON kernel */
+ std::unique_ptr<INEKernel> _optimised_kernel;
+ /** Input A */
+ const ITensor *_a;
+ /** Input B */
+ const ITensor *_b;
+ /** Output */
+ ITensor *_d;
/** Configures the array pointers and strides in the assembly kernel and executes the assembly kernel.
 * The call to set_arrays is needed to deal with input sizes containing batches (dims > 2); see the simplified sketch after the class definition.
@@ -91,10 +103,21 @@ public:
}
};
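A rough, simplified sketch of the run() flow documented above (illustration only, not the exact code in this header; the real arm_gemm set_arrays() also takes batch and multi strides so that batched inputs are re-based per batch):

// Sketch only: _a, _b, _d, _gemm_kernel_asm and _optimised_kernel are the members above;
// the set_arrays() signature shown here is simplified for readability.
const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
const auto *in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer());
const auto *in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
auto       *out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer());
_gemm_kernel_asm->set_arrays(in0_ptr, lda, in1_ptr, ldb, out_ptr, ldd); // simplified signature
NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);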
-using AssemblyKernelGlueF32 = AssemblyKernelGlue<float, float>;
+/** Float 32 assembly kernel glue */
+using AssemblyKernelGlueF32 = AssemblyKernelGlue<float, float>;
+/** Uint 8 to Uint 32 kernel glue */
using AssemblyKernelGlueU8U32 = AssemblyKernelGlue<uint8_t, uint32_t>;
+/** Int 8 to Int 32 kernel glue */
using AssemblyKernelGlueS8S32 = AssemblyKernelGlue<int8_t, int32_t>;
+/** Allocate a workspace tensor.
+ *
+ * @param[in] workspace_size Size to allocate.
+ * @param[out] workspace Tensor to allocate.
+ * @param[in] memory_group Tensor memory group.
+ * @param[in] alignment Workspace memory alignment.
+ * @param[in] num_threads Number of threads that will share the workspace.
+ */
inline void allocate_workspace(size_t workspace_size, Tensor &workspace, MemoryGroup &memory_group, size_t alignment, unsigned int num_threads)
{
ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
@@ -102,6 +125,17 @@ inline void allocate_workspace(size_t workspace_size, Tensor &workspace, MemoryG
workspace.allocator()->allocate();
}
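A hedged usage sketch for allocate_workspace() (the size, alignment and variable names below are placeholders; in practice the workspace size is queried from the assembly GEMM object):

// Illustration only: the 16 KiB size and 4096-byte alignment are placeholder values.
MemoryGroup        memory_group;
Tensor             workspace;
const size_t       workspace_size = 16 * 1024;
const unsigned int num_threads    = NEScheduler::get().num_threads();
if(workspace_size != 0)
{
    allocate_workspace(workspace_size, workspace, memory_group, 4096 /* alignment */, num_threads);
}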
+/** Create a wrapper kernel.
+ *
+ * @param[in] a Input tensor A.
+ * @param[in] b Input tensor B.
+ * @param[in] c (Optional) Input tensor C.
+ * @param[out] d Output tensor.
+ * @param[in] alpha Alpha value.
+ * @param[in] beta Beta value.
+ *
+ * @return the wrapper kernel.
+ */
template <typename T>
std::unique_ptr<NEGEMMAssemblyWrapper<T>> create_wrapper_kernel(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta)
{
@@ -128,6 +162,20 @@ std::unique_ptr<NEGEMMAssemblyWrapper<T>> create_wrapper_kernel(const ITensor *a
return nullptr;
}
+/** Setup assembly kernel.
+ *
+ * @param[in] a Input tensor A.
+ * @param[in] b Input tensor B.
+ * @param[in] c (Optional) Input tensor C.
+ * @param[out] d Output tensor.
+ * @param[in] alpha Alpha value.
+ * @param[in] beta Beta value.
+ * @param[out] workspace Workspace tensor.
+ * @param[in] memory_group Tensor memory group.
+ * @param[out] asm_glue Assembly glue kernel.
+ *
+ * @return True if the assembly kernel is set up correctly.
+ */
template <typename T>
inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta,
Tensor &workspace, MemoryGroup &memory_group, T &asm_glue)
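End to end, a caller such as an NEON GEMM function can use these helpers roughly as follows. This is a hedged sketch: it assumes a, b and d are valid, already-initialised ITensor pointers, that setup_assembly_kernel() returns true only when an optimised kernel was created, and that the glue exposes the run() method documented above.

// Configure time (sketch): try to set up the optimised assembly path.
AssemblyKernelGlueF32 asm_glue;
Tensor                workspace;
MemoryGroup           memory_group;
const bool run_optimised = setup_assembly_kernel(a, b, nullptr, d, 1.f /* alpha */, 0.f /* beta */,
                                                 workspace, memory_group, asm_glue);

// Run time (sketch): execute the assembly kernel, otherwise fall back to the reference NEON path.
if(run_optimised)
{
    memory_group.acquire();
    asm_glue.run();
    memory_group.release();
}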