diff options
Diffstat (limited to 'src/core/cpu')
-rw-r--r-- | src/core/cpu/kernels/assembly/arm_gemm.hpp | 10 | ||||
-rw-r--r-- | src/core/cpu/kernels/assembly/gemm_common.hpp | 7 |
2 files changed, 12 insertions, 5 deletions
diff --git a/src/core/cpu/kernels/assembly/arm_gemm.hpp b/src/core/cpu/kernels/assembly/arm_gemm.hpp index 81e355d6b3..e38cc09202 100644 --- a/src/core/cpu/kernels/assembly/arm_gemm.hpp +++ b/src/core/cpu/kernels/assembly/arm_gemm.hpp @@ -44,9 +44,7 @@ enum class GemmMethod GEMM_INTERLEAVED_2D, QUANTIZE_WRAPPER, QUANTIZE_WRAPPER_2D, - GEMM_HYBRID_QUANTIZED, - INDIRECT_GEMM, - CONVOLUTION_GEMM + GEMM_HYBRID_QUANTIZED }; struct KernelDescription @@ -113,13 +111,15 @@ public: bool _indirect_input; Activation _act; int _maxthreads; + bool _fast_mode; const GemmConfig *_cfg; GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N, unsigned int K, unsigned int Ksections, unsigned int nbatches, unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads, - const GemmConfig *cfg = nullptr) - : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), _cfg(cfg) + bool fast_mode = false, const GemmConfig *cfg = nullptr) + : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), _fast_mode(fast_mode), + _cfg(cfg) { } }; diff --git a/src/core/cpu/kernels/assembly/gemm_common.hpp b/src/core/cpu/kernels/assembly/gemm_common.hpp index 4af85ed663..378f1041be 100644 --- a/src/core/cpu/kernels/assembly/gemm_common.hpp +++ b/src/core/cpu/kernels/assembly/gemm_common.hpp @@ -30,6 +30,9 @@ namespace arm_gemm { +// Avoid circular dependency with arm_gemm.hpp +struct GemmConfig; + // Abstract class for the GEMM/GEMV functions. // // GEMM implementations may be "native" (never require any input @@ -137,6 +140,10 @@ public: { } + /*** Introspection interface ***/ + /* Get the configuration of this GEMM */ + virtual GemmConfig get_config() = 0; + // Destructor virtual ~IGemmCommon() { |