diff options
Diffstat (limited to 'src/gpu/cl/operators/ClGemm.h')
-rw-r--r-- | src/gpu/cl/operators/ClGemm.h | 89 |
1 files changed, 77 insertions, 12 deletions
diff --git a/src/gpu/cl/operators/ClGemm.h b/src/gpu/cl/operators/ClGemm.h index 11f9f2b3d8..85dc1d6c8f 100644 --- a/src/gpu/cl/operators/ClGemm.h +++ b/src/gpu/cl/operators/ClGemm.h @@ -90,30 +90,95 @@ public: * if the reshape of matrix B should happen only for the first run. GEMMInfo also contains information about the reshaping * in case matrix A and matrix B have been already transformed. */ - void configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + void configure(const CLCompileContext &compile_context, + ITensorInfo *a, + ITensorInfo *b, + ITensorInfo *c, + ITensorInfo *output, + float alpha, + float beta, + const GEMMInfo &gemm_info); /** Static function to check if given info will lead to a valid configuration * * Similar to ClGemm::configure() * * @return a status */ - static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + static Status validate(const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + const ITensorInfo *output, + float alpha, + float beta, + const GEMMInfo &gemm_info); // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &constants) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &constants) override; experimental::MemoryRequirements workspace() const override; private: - void configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); - void configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); - void configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); - void configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + void configure_native(const CLCompileContext &compile_context, + ITensorInfo *a, + ITensorInfo *b, + ITensorInfo *c, + ITensorInfo *output, + float alpha, + float beta, + const GEMMInfo &gemm_info); + void configure_reshaped(const CLCompileContext &compile_context, + ITensorInfo *a, + ITensorInfo *b, + ITensorInfo *c, + ITensorInfo *output, + float alpha, + float beta, + const GEMMInfo &gemm_info); + void configure_reshaped_only_rhs(const CLCompileContext &compile_context, + ITensorInfo *a, + ITensorInfo *b, + ITensorInfo *c, + ITensorInfo *output, + float alpha, + float beta, + const GEMMInfo &gemm_info); + void configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, + ITensorInfo *a, + ITensorInfo *b, + ITensorInfo *c, + ITensorInfo *output, + float alpha, + float beta, + const GEMMInfo &gemm_info); - static Status validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); - static Status validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); - static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); - static Status validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + static Status validate_native(const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + const ITensorInfo *output, + float alpha, + float beta, + const GEMMInfo &gemm_info); + static Status validate_reshaped(const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + const ITensorInfo *output, + float alpha, + float beta, + const GEMMInfo &gemm_info); + static Status validate_reshaped_only_rhs(const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + const ITensorInfo *output, + float alpha, + float beta, + const GEMMInfo &gemm_info); + static Status validate_reshaped_only_rhs_mmul(const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + const ITensorInfo *output, + float alpha, + float beta, + const GEMMInfo &gemm_info); private: enum AuxTensorIdx |