diff options
Diffstat (limited to 'src/cpu/kernels/CpuGemmMatrixAdditionKernel.h')
-rw-r--r-- | src/cpu/kernels/CpuGemmMatrixAdditionKernel.h | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h index 4a748218d1..cbc5b53087 100644 --- a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h +++ b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h @@ -43,7 +43,16 @@ namespace kernels */ class CpuGemmMatrixAdditionKernel : public ICpuKernel<CpuGemmMatrixAdditionKernel> { +private: + using GemmMatrixAddKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const Window &, float)>::type; + public: + struct GemmMatrixAddKernel + { + const char *name; + const DataTypeISASelectorPtr is_selected; + GemmMatrixAddKernelPtr ukernel; + }; CpuGemmMatrixAdditionKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmMatrixAdditionKernel); /** Initialise the kernel's input and output. @@ -69,6 +78,8 @@ public: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; + static const std::vector<GemmMatrixAddKernel> &get_available_kernels(); + private: /** Common signature for all the matrix addition functions * @@ -77,10 +88,9 @@ private: * @param[in] window Region on which to execute the kernel. * @param[in] beta Weight of matrix C */ - using MatrixAdditionFunctionPtr = void (*)(const ITensor *src, ITensor *dst, const Window &window, float beta); /** Matrix addition function to use for the particular tensor types passed to configure() */ - MatrixAdditionFunctionPtr _func{ nullptr }; - float _beta{ 0.f }; + GemmMatrixAddKernelPtr _func{ nullptr }; + float _beta{ 0.f }; }; } // namespace kernels } // namespace cpu |