From e39334c15c7fd141bb8173d5017ea5ca157fca2c Mon Sep 17 00:00:00 2001 From: David Mansell Date: Fri, 6 Jul 2018 17:53:35 +0100 Subject: COMPMID-1271: New system for GEMM heuristics This patch implements a system for separating the "validity" from "preferred" aspect of the current heuristics in gemm_*.cpp. Now, each gemm_*.cpp defines a list of candidate implementations, each of which supplies an is_valid() function (to check for validity), an is_preferred() function (the "heuristic" part), and an instantiate() function which actually produces the GemmCommon object pointer. The actual gemm() function is now templated and uses this list to select an implementation. This patch also implements a mechanism to identify the preferred implementation, and override it via the GemmConfig structure. Change-Id: Id49ab7af8bf2e3e9fd951a9698883ade234d40e1 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139120 Reviewed-by: Anthony Barbier Tested-by: Jenkins --- src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) (limited to 'src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp') diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp index 6db55c02d0..f4b712c453 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp @@ -25,20 +25,36 @@ #include "arm_gemm.hpp" #include "gemm_common.hpp" +#include "gemm_implementation.hpp" #include "gemm_interleaved.hpp" #include "kernels/a64_gemm_u16_12x8.hpp" namespace arm_gemm { +class GemmImpl_gemm_u16_interleaved : public GemmImplementation { +public: + UniqueGemmCommon instantiate(const GemmArgs &args) override { + return UniqueGemmCommon(new GemmInterleaved(args)); + } + + GemmImpl_gemm_u16_interleaved() : GemmImplementation(GemmMethod::GEMM_INTERLEAVED) { } +}; + +static std::vector *> gemm_u16_methods = { + new GemmImpl_gemm_u16_interleaved() +}; + template<> -UniqueGemmCommon gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K, - const unsigned int nbatches, const unsigned int nmulti, - const bool trA, const bool trB, uint32_t alpha, uint32_t beta, - const int maxthreads, const bool pretransposed_hint) { - return UniqueGemmCommon(new GemmInterleaved(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); +std::vector *> &gemm_implementation_list() { + return gemm_u16_methods; } +/* Explicitly instantiate the external functions for these types. */ +template UniqueGemmCommon gemm(GemmArgs &args, GemmConfig *cfg); +template GemmMethod get_gemm_method(GemmArgs &args); +template bool method_is_compatible(GemmMethod method, GemmArgs &args); + } // namespace arm_gemm #endif // __aarch64__ -- cgit v1.2.1