diff options
Diffstat (limited to 'src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h')
-rw-r--r-- | src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h b/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h index 9c3dc8b1a0..a7dfec87bd 100644 --- a/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h +++ b/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h @@ -41,7 +41,17 @@ namespace kernels */ class CpuGemmMatrixMultiplyKernel : public ICpuKernel<CpuGemmMatrixMultiplyKernel> { +private: + using GemmMatrixMulKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, const Window &, const ThreadInfo &, float, const bool)>::type; + public: + struct GemmMatrixMulKernel + { + const char *name; + const DataTypeISASelectorPtr is_selected; + GemmMatrixMulKernelPtr ukernel; + }; + CpuGemmMatrixMultiplyKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmMatrixMultiplyKernel); /** Initialise the kernel's input and output. @@ -70,6 +80,8 @@ public: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; + static const std::vector<GemmMatrixMulKernel> &get_available_kernels(); + private: /** Common signature for all the matrix multiply functions * @@ -80,10 +92,10 @@ private: * @param[in] info Thread info metadata. * @param[in] alpha Weight of the matrix product. */ - using GemmFunctionPtr = void(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha); + /** Matrix multiply function to use for the particular tensor types passed to configure() */ - GemmFunctionPtr *_func{ nullptr }; - float _alpha{ 1.f }; + GemmMatrixMulKernelPtr _func{ nullptr }; + float _alpha{ 1.f }; }; } // namespace kernels } // namespace cpu |