diff options
author | Dana Zlotnik <dana.zlotnik@arm.com> | 2022-02-02 15:06:11 +0200 |
---|---|---|
committer | Dana Zlotnik <dana.zlotnik@arm.com> | 2022-02-14 12:56:32 +0000 |
commit | 256ac62029d835c55b08951af0bdfa542b878956 (patch) | |
tree | ffcc19a06e96c7cb74b6202cc19f104dc84e76e4 /src/cpu/kernels/CpuGemmMatrixAdditionKernel.h | |
parent | 149203bc23d5c84fe1326d9dea4730750fab6710 (diff) | |
download | ComputeLibrary-256ac62029d835c55b08951af0bdfa542b878956.tar.gz |
Decouple CpuGemmMatrixMultiplyKernel and CpuGemmMatrixAdditionKernel
Resolves COMPMID-4629, COMPMID-4631
Change-Id: Idceafc84735116ef63ec13a202895f954b87e32f
Signed-off-by: Dana Zlotnik <dana.zlotnik@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7095
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/kernels/CpuGemmMatrixAdditionKernel.h')
-rw-r--r-- | src/cpu/kernels/CpuGemmMatrixAdditionKernel.h | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h index 4a748218d1..cbc5b53087 100644 --- a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h +++ b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h @@ -43,7 +43,16 @@ namespace kernels */ class CpuGemmMatrixAdditionKernel : public ICpuKernel<CpuGemmMatrixAdditionKernel> { +private: + using GemmMatrixAddKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const Window &, float)>::type; + public: + struct GemmMatrixAddKernel + { + const char *name; + const DataTypeISASelectorPtr is_selected; + GemmMatrixAddKernelPtr ukernel; + }; CpuGemmMatrixAdditionKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmMatrixAdditionKernel); /** Initialise the kernel's input and output. @@ -69,6 +78,8 @@ public: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; + static const std::vector<GemmMatrixAddKernel> &get_available_kernels(); + private: /** Common signature for all the matrix addition functions * @@ -77,10 +88,9 @@ private: * @param[in] window Region on which to execute the kernel. * @param[in] beta Weight of matrix C */ - using MatrixAdditionFunctionPtr = void (*)(const ITensor *src, ITensor *dst, const Window &window, float beta); /** Matrix addition function to use for the particular tensor types passed to configure() */ - MatrixAdditionFunctionPtr _func{ nullptr }; - float _beta{ 0.f }; + GemmMatrixAddKernelPtr _func{ nullptr }; + float _beta{ 0.f }; }; } // namespace kernels } // namespace cpu |