diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h')
-rw-r--r-- | arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h | 25 |
1 files changed, 18 insertions, 7 deletions
diff --git a/arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h b/arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h index 7ddbf4bca8..322932bab2 100644 --- a/arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h +++ b/arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2019 Arm Limited. + * Copyright (c) 2016-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -77,15 +77,26 @@ public: void run(const Window &window, const ThreadInfo &info) override; private: - /** Common signature for all the transpose functions + /** Template function to run gemm interleave 4x4 * - * @param[in] input An input tensor. Data types supported: All - * @param[out] output The output tensor. Data type supported: same as @p input - * @param[in] window Region on which to execute the kernel. + * @tparam ScalarType Scalar datatype + * + * @param[in] input Input tensor. Data types supported: uint32_t, uint16_t and uint8_t + * @param[out] output Output tensor. Data types supported: uint32_t, uint16_t and uint8_t + * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()). + */ + template <typename ScalarType> + void gemm_interleave4x4(const ITensor *input, ITensor *output, const Window &window); + + /** Common signature for all the specialised gemm interleave 4x4 functions + * + * @param[in] input Input tensor. Data types supported: uint32_t, uint16_t and uint8_t + * @param[out] output Output tensor. Data types supported: uint32_t, uint16_t and uint8_t + * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()). */ - using GEMMInterleaveFunction = void(const ITensor *input, ITensor *output, const Window &window); + using GEMMInterleaveFunctionFuncPtr = void (NEGEMMInterleave4x4Kernel::*)(const ITensor *input, ITensor *output, const Window &window); - GEMMInterleaveFunction *_func; /**< GEMM interleave function to use for the particular tensor types passed to configure() */ + GEMMInterleaveFunctionFuncPtr _func; }; } // namespace arm_compute #endif /*ARM_COMPUTE_NEGEMMINTERLEAVE4x4KERNEL_H*/ |