aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp')
-rw-r--r--arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp88
1 files changed, 58 insertions, 30 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
index 162cbc5c46..26c1f3df89 100644
--- a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,6 +24,7 @@
#pragma once
#include <memory>
+#include <cstring>
#include "arm_gemm_local.hpp"
#include "gemm_common.hpp"
@@ -37,45 +38,57 @@ enum class GemmMethod
GEMV_PRETRANSPOSED,
GEMV_NATIVE_TRANSPOSED,
GEMM_NATIVE,
- GEMM_INTERLEAVED,
- GEMM_INTERLEAVED_FP16,
- GEMM_INTERLEAVED_DOT
+ GEMM_HYBRID,
+ GEMM_INTERLEAVED
+};
+
+
+struct KernelDescription
+{
+ GemmMethod method = GemmMethod::DEFAULT;
+ std::string name = "";
+
+ KernelDescription(GemmMethod m, std::string n) : method(m), name(n) { }
+ KernelDescription() { }
};
struct GemmConfig
{
- GemmMethod method = GemmMethod::DEFAULT;
+ GemmMethod method = GemmMethod::DEFAULT;
+ std::string filter = "";
unsigned int inner_block_size = 0;
unsigned int outer_block_size = 0;
GemmConfig(GemmMethod method) : method(method) { }
+ GemmConfig() { }
};
template<typename T>
struct GemmArgs
{
public:
- const CPUInfo *_ci;
- unsigned int _Msize;
- unsigned int _Nsize;
- unsigned int _Ksize;
- unsigned int _nbatches;
- unsigned int _nmulti;
- bool _trA;
- bool _trB;
- T _alpha;
- T _beta;
- int _maxthreads;
- bool _pretransposed_hint;
+ const CPUInfo *_ci;
+ unsigned int _Msize;
+ unsigned int _Nsize;
+ unsigned int _Ksize;
+ unsigned int _nbatches;
+ unsigned int _nmulti;
+ bool _trA;
+ bool _trB;
+ T _alpha;
+ T _beta;
+ int _maxthreads;
+ bool _pretransposed_hint;
+ const GemmConfig *_cfg;
GemmArgs(const CPUInfo *ci, const unsigned int M, const unsigned int N,
const unsigned int K, const unsigned int nbatches,
const unsigned int nmulti, const bool trA, const bool trB,
const T alpha, const T beta, const int maxthreads,
- const bool pretransposed_hint) :
- _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti),
- _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads),
- _pretransposed_hint(pretransposed_hint)
+ const bool pretransposed_hint, const GemmConfig *cfg=nullptr ) :
+ _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti),
+ _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads),
+ _pretransposed_hint(pretransposed_hint), _cfg(cfg)
{
}
};
@@ -90,7 +103,7 @@ using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret> >;
* provided parameters be provided using the supplied method? */
template<typename Top, typename Tret>
-bool method_is_compatible(GemmMethod method, GemmArgs<Tret> &args);
+bool method_is_compatible(GemmMethod method, const GemmArgs<Tret> &args);
template<typename Top, typename Tret>
bool method_is_compatible(GemmMethod method, const CPUInfo &ci,
@@ -107,14 +120,14 @@ bool method_is_compatible(GemmMethod method, const CPUInfo &ci,
/* get_gemm_method(): Given the templated types and provided parameters,
* which is the preferred method to implement this GEMM? */
template<typename Top, typename Tret>
-GemmMethod get_gemm_method(GemmArgs<Tret> &args);
+KernelDescription get_gemm_method(const GemmArgs<Tret> &args);
template<typename Top, typename Tret>
-GemmMethod get_gemm_method(const CPUInfo &ci,
- const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const Tret alpha, const Tret beta,
- const int maxthreads, const bool pretransposed_hint)
+KernelDescription get_gemm_method(const CPUInfo &ci,
+ const unsigned int M, const unsigned int N, const unsigned int K,
+ const unsigned int nbatches, const unsigned int nmulti,
+ const bool trA, const bool trB, const Tret alpha, const Tret beta,
+ const int maxthreads, const bool pretransposed_hint)
{
GemmArgs<Tret> args(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint);
@@ -122,7 +135,7 @@ GemmMethod get_gemm_method(const CPUInfo &ci,
}
template<typename Top, typename Tret>
-UniqueGemmCommon<Top, Tret> gemm(GemmArgs<Tret> &args, GemmConfig *cfg);
+UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args);
/** Request an object to process a GEMM.
*
@@ -147,9 +160,24 @@ UniqueGemmCommon<Top, Tret> gemm(const CPUInfo &ci,
const bool trA, const bool trB, const Tret alpha, const Tret beta,
const int maxthreads, const bool pretransposed_hint, GemmConfig *cfg=nullptr)
{
+ GemmArgs<Tret> args(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint, cfg);
+
+ return gemm<Top, Tret>(args);
+}
+
+template<typename Top, typename Tret>
+std::vector<std::string> get_compatible_kernels(const GemmArgs<Tret> &args);
+
+template<typename Top, typename Tret>
+std::vector<std::string> get_compatible_kernels(const CPUInfo &ci,
+ const unsigned int M, const unsigned int N, const unsigned int K,
+ const unsigned int nbatches, const unsigned int nmulti,
+ const bool trA, const bool trB, const Tret alpha, const Tret beta,
+ const int maxthreads, const bool pretransposed_hint)
+{
GemmArgs<Tret> args(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint);
- return gemm<Top, Tret>(args, cfg);
+ return get_compatible_kernels<Top, Tret>(args);
}
} // namespace arm_gemm