diff options
Diffstat (limited to 'src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h')
-rw-r--r-- | src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h | 23 |
1 files changed, 17 insertions, 6 deletions
diff --git a/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h b/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h index a7dfec87bd..765fcb8275 100644 --- a/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h +++ b/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h @@ -42,7 +42,8 @@ namespace kernels class CpuGemmMatrixMultiplyKernel : public ICpuKernel<CpuGemmMatrixMultiplyKernel> { private: - using GemmMatrixMulKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, const Window &, const ThreadInfo &, float, const bool)>::type; + using GemmMatrixMulKernelPtr = std::add_pointer<void( + const ITensor *, const ITensor *, ITensor *, const Window &, const ThreadInfo &, float, const bool)>::type; public: struct GemmMatrixMulKernel @@ -67,17 +68,27 @@ public: * @param[in] is_interleaved (Optional) True if lhs and rhs have been reshaped respectively using @ref CpuGemmInterleave4x4Kernel and @ref CpuGemmTranspose1xWKernel * @param[in] reshape_info (Optional) GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how @p lhs and @p rhs have been reshaped */ - void configure(const ITensorInfo *lhs, const ITensorInfo *rhs, ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo()); + void configure(const ITensorInfo *lhs, + const ITensorInfo *rhs, + ITensorInfo *dst, + float alpha, + bool is_interleaved, + const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo()); /** Static function to check if given info will lead to a valid configuration of @ref CpuGemmMatrixMultiplyKernel * * Similar to @ref CpuGemmMatrixMultiplyKernel::configure() * * @return a status */ - static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info); + static Status validate(const ITensorInfo *lhs, + const ITensorInfo *rhs, + const ITensorInfo *dst, + float alpha, + bool is_interleaved, + const GEMMReshapeInfo &reshape_info); // Inherited methods overridden: - void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; + void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; static const std::vector<GemmMatrixMulKernel> &get_available_kernels(); @@ -94,8 +105,8 @@ private: */ /** Matrix multiply function to use for the particular tensor types passed to configure() */ - GemmMatrixMulKernelPtr _func{ nullptr }; - float _alpha{ 1.f }; + GemmMatrixMulKernelPtr _func{nullptr}; + float _alpha{1.f}; }; } // namespace kernels } // namespace cpu |