aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h')
-rw-r--r--src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h23
1 files changed, 17 insertions, 6 deletions
diff --git a/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h b/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h
index a7dfec87bd..765fcb8275 100644
--- a/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h
+++ b/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h
@@ -42,7 +42,8 @@ namespace kernels
class CpuGemmMatrixMultiplyKernel : public ICpuKernel<CpuGemmMatrixMultiplyKernel>
{
private:
- using GemmMatrixMulKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, const Window &, const ThreadInfo &, float, const bool)>::type;
+ using GemmMatrixMulKernelPtr = std::add_pointer<void(
+ const ITensor *, const ITensor *, ITensor *, const Window &, const ThreadInfo &, float, const bool)>::type;
public:
struct GemmMatrixMulKernel
@@ -67,17 +68,27 @@ public:
* @param[in] is_interleaved (Optional) True if lhs and rhs have been reshaped respectively using @ref CpuGemmInterleave4x4Kernel and @ref CpuGemmTranspose1xWKernel
* @param[in] reshape_info (Optional) GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how @p lhs and @p rhs have been reshaped
*/
- void configure(const ITensorInfo *lhs, const ITensorInfo *rhs, ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo());
+ void configure(const ITensorInfo *lhs,
+ const ITensorInfo *rhs,
+ ITensorInfo *dst,
+ float alpha,
+ bool is_interleaved,
+ const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo());
/** Static function to check if given info will lead to a valid configuration of @ref CpuGemmMatrixMultiplyKernel
*
* Similar to @ref CpuGemmMatrixMultiplyKernel::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info);
+ static Status validate(const ITensorInfo *lhs,
+ const ITensorInfo *rhs,
+ const ITensorInfo *dst,
+ float alpha,
+ bool is_interleaved,
+ const GEMMReshapeInfo &reshape_info);
// Inherited methods overridden:
- void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
+ void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
static const std::vector<GemmMatrixMulKernel> &get_available_kernels();
@@ -94,8 +105,8 @@ private:
*/
/** Matrix multiply function to use for the particular tensor types passed to configure() */
- GemmMatrixMulKernelPtr _func{ nullptr };
- float _alpha{ 1.f };
+ GemmMatrixMulKernelPtr _func{nullptr};
+ float _alpha{1.f};
};
} // namespace kernels
} // namespace cpu