diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp')
-rw-r--r-- | arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp | 88 |
1 files changed, 58 insertions, 30 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp index 162cbc5c46..26c1f3df89 100644 --- a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp +++ b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,6 +24,7 @@ #pragma once #include <memory> +#include <cstring> #include "arm_gemm_local.hpp" #include "gemm_common.hpp" @@ -37,45 +38,57 @@ enum class GemmMethod GEMV_PRETRANSPOSED, GEMV_NATIVE_TRANSPOSED, GEMM_NATIVE, - GEMM_INTERLEAVED, - GEMM_INTERLEAVED_FP16, - GEMM_INTERLEAVED_DOT + GEMM_HYBRID, + GEMM_INTERLEAVED +}; + + +struct KernelDescription +{ + GemmMethod method = GemmMethod::DEFAULT; + std::string name = ""; + + KernelDescription(GemmMethod m, std::string n) : method(m), name(n) { } + KernelDescription() { } }; struct GemmConfig { - GemmMethod method = GemmMethod::DEFAULT; + GemmMethod method = GemmMethod::DEFAULT; + std::string filter = ""; unsigned int inner_block_size = 0; unsigned int outer_block_size = 0; GemmConfig(GemmMethod method) : method(method) { } + GemmConfig() { } }; template<typename T> struct GemmArgs { public: - const CPUInfo *_ci; - unsigned int _Msize; - unsigned int _Nsize; - unsigned int _Ksize; - unsigned int _nbatches; - unsigned int _nmulti; - bool _trA; - bool _trB; - T _alpha; - T _beta; - int _maxthreads; - bool _pretransposed_hint; + const CPUInfo *_ci; + unsigned int _Msize; + unsigned int _Nsize; + unsigned int _Ksize; + unsigned int _nbatches; + unsigned int _nmulti; + bool _trA; + bool _trB; + T _alpha; + T _beta; + int _maxthreads; + bool _pretransposed_hint; + const GemmConfig *_cfg; GemmArgs(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, const T alpha, const T beta, const int maxthreads, - const bool pretransposed_hint) : - _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), - _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads), - _pretransposed_hint(pretransposed_hint) + const bool pretransposed_hint, const GemmConfig *cfg=nullptr ) : + _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), + _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads), + _pretransposed_hint(pretransposed_hint), _cfg(cfg) { } }; @@ -90,7 +103,7 @@ using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret> >; * provided parameters be provided using the supplied method? */ template<typename Top, typename Tret> -bool method_is_compatible(GemmMethod method, GemmArgs<Tret> &args); +bool method_is_compatible(GemmMethod method, const GemmArgs<Tret> &args); template<typename Top, typename Tret> bool method_is_compatible(GemmMethod method, const CPUInfo &ci, @@ -107,14 +120,14 @@ bool method_is_compatible(GemmMethod method, const CPUInfo &ci, /* get_gemm_method(): Given the templated types and provided parameters, * which is the preferred method to implement this GEMM? */ template<typename Top, typename Tret> -GemmMethod get_gemm_method(GemmArgs<Tret> &args); +KernelDescription get_gemm_method(const GemmArgs<Tret> &args); template<typename Top, typename Tret> -GemmMethod get_gemm_method(const CPUInfo &ci, - const unsigned int M, const unsigned int N, const unsigned int K, - const unsigned int nbatches, const unsigned int nmulti, - const bool trA, const bool trB, const Tret alpha, const Tret beta, - const int maxthreads, const bool pretransposed_hint) +KernelDescription get_gemm_method(const CPUInfo &ci, + const unsigned int M, const unsigned int N, const unsigned int K, + const unsigned int nbatches, const unsigned int nmulti, + const bool trA, const bool trB, const Tret alpha, const Tret beta, + const int maxthreads, const bool pretransposed_hint) { GemmArgs<Tret> args(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint); @@ -122,7 +135,7 @@ GemmMethod get_gemm_method(const CPUInfo &ci, } template<typename Top, typename Tret> -UniqueGemmCommon<Top, Tret> gemm(GemmArgs<Tret> &args, GemmConfig *cfg); +UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args); /** Request an object to process a GEMM. * @@ -147,9 +160,24 @@ UniqueGemmCommon<Top, Tret> gemm(const CPUInfo &ci, const bool trA, const bool trB, const Tret alpha, const Tret beta, const int maxthreads, const bool pretransposed_hint, GemmConfig *cfg=nullptr) { + GemmArgs<Tret> args(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint, cfg); + + return gemm<Top, Tret>(args); +} + +template<typename Top, typename Tret> +std::vector<std::string> get_compatible_kernels(const GemmArgs<Tret> &args); + +template<typename Top, typename Tret> +std::vector<std::string> get_compatible_kernels(const CPUInfo &ci, + const unsigned int M, const unsigned int N, const unsigned int K, + const unsigned int nbatches, const unsigned int nmulti, + const bool trA, const bool trB, const Tret alpha, const Tret beta, + const int maxthreads, const bool pretransposed_hint) +{ GemmArgs<Tret> args(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint); - return gemm<Top, Tret>(args, cfg); + return get_compatible_kernels<Top, Tret>(args); } } // namespace arm_gemm |