aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h')
-rw-r--r--src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h41
1 files changed, 27 insertions, 14 deletions
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h
index 8d25412a40..30928c4e1d 100644
--- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h
+++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h
@@ -24,12 +24,12 @@
#ifndef ACL_SRC_GPU_CL_KERNELS_CLGEMMMATRIXMULTIPLYRESHAPEDKERNEL_H
#define ACL_SRC_GPU_CL_KERNELS_CLGEMMMATRIXMULTIPLYRESHAPEDKERNEL_H
+#include "arm_compute/core/KernelDescriptors.h"
+
#include "src/core/common/Macros.h"
#include "src/gpu/cl/ClCompileContext.h"
#include "src/gpu/cl/IClKernel.h"
-#include "arm_compute/core/KernelDescriptors.h"
-
namespace arm_compute
{
namespace opencl
@@ -83,16 +83,29 @@ public:
*
* @note lhs_info.k0 must be equal to rhs_info.k0
*/
- void configure(const ClCompileContext &compile_context,
- const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, float alpha, float beta,
- const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info);
+ void configure(const ClCompileContext &compile_context,
+ const ITensorInfo *src0,
+ const ITensorInfo *src1,
+ const ITensorInfo *src2,
+ ITensorInfo *dst,
+ float alpha,
+ float beta,
+ const GEMMLHSMatrixInfo &lhs_info,
+ const GEMMRHSMatrixInfo &rhs_info,
+ const GEMMKernelInfo &gemm_info);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to @ref ClGemmMatrixMultiplyReshapedKernel::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
+ static Status validate(const ITensorInfo *src0,
+ const ITensorInfo *src1,
+ const ITensorInfo *src2,
+ const ITensorInfo *dst,
+ float alpha,
+ float beta,
+ const GEMMLHSMatrixInfo &lhs_info,
const GEMMRHSMatrixInfo &rhs_info,
const GEMMKernelInfo &gemm_info);
@@ -100,14 +113,14 @@ public:
void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override;
private:
- bool _slide_matrix_b{ true };
- bool _reinterpret_output_as_3d{ false };
- bool _use_dummy_work_items{ false };
- bool _add_bias{ false };
- bool _export_to_cl_image{ false };
- signed int _m{ 1 };
- signed int _n{ 1 };
- signed int _k{ 1 };
+ bool _slide_matrix_b{true};
+ bool _reinterpret_output_as_3d{false};
+ bool _use_dummy_work_items{false};
+ bool _add_bias{false};
+ bool _export_to_cl_image{false};
+ signed int _m{1};
+ signed int _n{1};
+ signed int _k{1};
};
} // namespace kernels
} // namespace opencl