aboutsummaryrefslogtreecommitdiff
path: root/src/core/cpu
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/cpu')
-rw-r--r--src/core/cpu/kernels/assembly/arm_gemm.hpp10
-rw-r--r--src/core/cpu/kernels/assembly/gemm_common.hpp7
2 files changed, 12 insertions, 5 deletions
diff --git a/src/core/cpu/kernels/assembly/arm_gemm.hpp b/src/core/cpu/kernels/assembly/arm_gemm.hpp
index 81e355d6b3..e38cc09202 100644
--- a/src/core/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/core/cpu/kernels/assembly/arm_gemm.hpp
@@ -44,9 +44,7 @@ enum class GemmMethod
GEMM_INTERLEAVED_2D,
QUANTIZE_WRAPPER,
QUANTIZE_WRAPPER_2D,
- GEMM_HYBRID_QUANTIZED,
- INDIRECT_GEMM,
- CONVOLUTION_GEMM
+ GEMM_HYBRID_QUANTIZED
};
struct KernelDescription
@@ -113,13 +111,15 @@ public:
bool _indirect_input;
Activation _act;
int _maxthreads;
+ bool _fast_mode;
const GemmConfig *_cfg;
GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N,
unsigned int K, unsigned int Ksections, unsigned int nbatches,
unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads,
- const GemmConfig *cfg = nullptr)
- : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), _cfg(cfg)
+ bool fast_mode = false, const GemmConfig *cfg = nullptr)
+ : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), _fast_mode(fast_mode),
+ _cfg(cfg)
{
}
};
diff --git a/src/core/cpu/kernels/assembly/gemm_common.hpp b/src/core/cpu/kernels/assembly/gemm_common.hpp
index 4af85ed663..378f1041be 100644
--- a/src/core/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/core/cpu/kernels/assembly/gemm_common.hpp
@@ -30,6 +30,9 @@
namespace arm_gemm
{
+// Avoid circular dependency with arm_gemm.hpp
+struct GemmConfig;
+
// Abstract class for the GEMM/GEMV functions.
//
// GEMM implementations may be "native" (never require any input
@@ -137,6 +140,10 @@ public:
{
}
+ /*** Introspection interface ***/
+ /* Get the configuration of this GEMM */
+ virtual GemmConfig get_config() = 0;
+
// Destructor
virtual ~IGemmCommon()
{