aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuGemmMatrixAdditionKernel.h')
-rw-r--r--src/cpu/kernels/CpuGemmMatrixAdditionKernel.h16
1 files changed, 13 insertions, 3 deletions
diff --git a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h
index 4a748218d1..cbc5b53087 100644
--- a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h
+++ b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h
@@ -43,7 +43,16 @@ namespace kernels
*/
class CpuGemmMatrixAdditionKernel : public ICpuKernel<CpuGemmMatrixAdditionKernel>
{
+private:
+ using GemmMatrixAddKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const Window &, float)>::type;
+
public:
+ struct GemmMatrixAddKernel
+ {
+ const char *name;
+ const DataTypeISASelectorPtr is_selected;
+ GemmMatrixAddKernelPtr ukernel;
+ };
CpuGemmMatrixAdditionKernel() = default;
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmMatrixAdditionKernel);
/** Initialise the kernel's input and output.
@@ -69,6 +78,8 @@ public:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
+ static const std::vector<GemmMatrixAddKernel> &get_available_kernels();
+
private:
/** Common signature for all the matrix addition functions
*
@@ -77,10 +88,9 @@ private:
* @param[in] window Region on which to execute the kernel.
* @param[in] beta Weight of matrix C
*/
- using MatrixAdditionFunctionPtr = void (*)(const ITensor *src, ITensor *dst, const Window &window, float beta);
/** Matrix addition function to use for the particular tensor types passed to configure() */
- MatrixAdditionFunctionPtr _func{ nullptr };
- float _beta{ 0.f };
+ GemmMatrixAddKernelPtr _func{ nullptr };
+ float _beta{ 0.f };
};
} // namespace kernels
} // namespace cpu