aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h
diff options
context:
space:
mode:
authorDana Zlotnik <dana.zlotnik@arm.com>2022-02-02 15:06:11 +0200
committerDana Zlotnik <dana.zlotnik@arm.com>2022-02-14 12:56:32 +0000
commit256ac62029d835c55b08951af0bdfa542b878956 (patch)
treeffcc19a06e96c7cb74b6202cc19f104dc84e76e4 /src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h
parent149203bc23d5c84fe1326d9dea4730750fab6710 (diff)
downloadComputeLibrary-256ac62029d835c55b08951af0bdfa542b878956.tar.gz
Decouple CpuGemmMatrixMultiplyKernel and CpuGemmMatrixAdditionKernel
Resolves COMPMID-4629, COMPMID-4631 Change-Id: Idceafc84735116ef63ec13a202895f954b87e32f Signed-off-by: Dana Zlotnik <dana.zlotnik@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7095 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: SiCong Li <sicong.li@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h')
-rw-r--r--src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h18
1 files changed, 15 insertions, 3 deletions
diff --git a/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h b/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h
index 9c3dc8b1a0..a7dfec87bd 100644
--- a/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h
+++ b/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h
@@ -41,7 +41,17 @@ namespace kernels
*/
class CpuGemmMatrixMultiplyKernel : public ICpuKernel<CpuGemmMatrixMultiplyKernel>
{
+private:
+ using GemmMatrixMulKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, const Window &, const ThreadInfo &, float, const bool)>::type;
+
public:
+ struct GemmMatrixMulKernel
+ {
+ const char *name;
+ const DataTypeISASelectorPtr is_selected;
+ GemmMatrixMulKernelPtr ukernel;
+ };
+
CpuGemmMatrixMultiplyKernel() = default;
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmMatrixMultiplyKernel);
/** Initialise the kernel's input and output.
@@ -70,6 +80,8 @@ public:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
+ static const std::vector<GemmMatrixMulKernel> &get_available_kernels();
+
private:
/** Common signature for all the matrix multiply functions
*
@@ -80,10 +92,10 @@ private:
* @param[in] info Thread info metadata.
* @param[in] alpha Weight of the matrix product.
*/
- using GemmFunctionPtr = void(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha);
+
/** Matrix multiply function to use for the particular tensor types passed to configure() */
- GemmFunctionPtr *_func{ nullptr };
- float _alpha{ 1.f };
+ GemmMatrixMulKernelPtr _func{ nullptr };
+ float _alpha{ 1.f };
};
} // namespace kernels
} // namespace cpu