aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/operators/ClGemm.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/operators/ClGemm.h')
-rw-r--r--src/gpu/cl/operators/ClGemm.h89
1 files changed, 77 insertions, 12 deletions
diff --git a/src/gpu/cl/operators/ClGemm.h b/src/gpu/cl/operators/ClGemm.h
index 11f9f2b3d8..85dc1d6c8f 100644
--- a/src/gpu/cl/operators/ClGemm.h
+++ b/src/gpu/cl/operators/ClGemm.h
@@ -90,30 +90,95 @@ public:
* if the reshape of matrix B should happen only for the first run. GEMMInfo also contains information about the reshaping
* in case matrix A and matrix B have been already transformed.
*/
- void configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+ void configure(const CLCompileContext &compile_context,
+ ITensorInfo *a,
+ ITensorInfo *b,
+ ITensorInfo *c,
+ ITensorInfo *output,
+ float alpha,
+ float beta,
+ const GEMMInfo &gemm_info);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to ClGemm::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+ static Status validate(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ const ITensorInfo *output,
+ float alpha,
+ float beta,
+ const GEMMInfo &gemm_info);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &constants) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &constants) override;
experimental::MemoryRequirements workspace() const override;
private:
- void configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
- void configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
- void configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
- void configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+ void configure_native(const CLCompileContext &compile_context,
+ ITensorInfo *a,
+ ITensorInfo *b,
+ ITensorInfo *c,
+ ITensorInfo *output,
+ float alpha,
+ float beta,
+ const GEMMInfo &gemm_info);
+ void configure_reshaped(const CLCompileContext &compile_context,
+ ITensorInfo *a,
+ ITensorInfo *b,
+ ITensorInfo *c,
+ ITensorInfo *output,
+ float alpha,
+ float beta,
+ const GEMMInfo &gemm_info);
+ void configure_reshaped_only_rhs(const CLCompileContext &compile_context,
+ ITensorInfo *a,
+ ITensorInfo *b,
+ ITensorInfo *c,
+ ITensorInfo *output,
+ float alpha,
+ float beta,
+ const GEMMInfo &gemm_info);
+ void configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context,
+ ITensorInfo *a,
+ ITensorInfo *b,
+ ITensorInfo *c,
+ ITensorInfo *output,
+ float alpha,
+ float beta,
+ const GEMMInfo &gemm_info);
- static Status validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
- static Status validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
- static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
- static Status validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+ static Status validate_native(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ const ITensorInfo *output,
+ float alpha,
+ float beta,
+ const GEMMInfo &gemm_info);
+ static Status validate_reshaped(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ const ITensorInfo *output,
+ float alpha,
+ float beta,
+ const GEMMInfo &gemm_info);
+ static Status validate_reshaped_only_rhs(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ const ITensorInfo *output,
+ float alpha,
+ float beta,
+ const GEMMInfo &gemm_info);
+ static Status validate_reshaped_only_rhs_mmul(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ const ITensorInfo *output,
+ float alpha,
+ float beta,
+ const GEMMInfo &gemm_info);
private:
enum AuxTensorIdx