diff options
Diffstat (limited to 'src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h')
-rw-r--r-- | src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h | 36 |
1 files changed, 25 insertions, 11 deletions
diff --git a/src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h b/src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h index 0ae549cd53..fc8b73140d 100644 --- a/src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h +++ b/src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h @@ -25,6 +25,7 @@ #define ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMUL_KERNEL_H #include "arm_compute/core/KernelDescriptors.h" + #include "src/core/common/Macros.h" #include "src/gpu/cl/IClKernel.h" @@ -65,29 +66,42 @@ public: * @param[in] output_multipliers (Optional) Output multipliers tensor. Supported data types: S32. * @param[in] output_shifts (Optional) Output shifts tensor. Supported data types: S32. */ - void configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst, const GEMMKernelInfo &gemm_info, - ITensorInfo *vector_sum_col = nullptr, const ITensorInfo *vector_sum_row = nullptr, ITensorInfo *bias = nullptr, - ITensorInfo *output_multipliers = nullptr, ITensorInfo *output_shifts = nullptr); + void configure(const CLCompileContext &compile_context, + const ITensorInfo *src0, + const ITensorInfo *src1, + ITensorInfo *dst, + const GEMMKernelInfo &gemm_info, + ITensorInfo *vector_sum_col = nullptr, + const ITensorInfo *vector_sum_row = nullptr, + ITensorInfo *bias = nullptr, + ITensorInfo *output_multipliers = nullptr, + ITensorInfo *output_shifts = nullptr); /** Static function to check if given info will lead to a valid configuration * * Similar to @ref ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel::configure() * * @return a status */ - static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst, const GEMMKernelInfo &gemm_info, - const ITensorInfo *vector_sum_col = nullptr, const ITensorInfo *vector_sum_row = nullptr, const ITensorInfo *bias = nullptr, - const ITensorInfo *output_multipliers = nullptr, const ITensorInfo *output_shifts = nullptr); + static Status validate(const ITensorInfo *src0, + const ITensorInfo *src1, + const ITensorInfo *dst, + const GEMMKernelInfo &gemm_info, + const ITensorInfo *vector_sum_col = nullptr, + const ITensorInfo *vector_sum_row = nullptr, + const ITensorInfo *bias = nullptr, + const ITensorInfo *output_multipliers = nullptr, + const ITensorInfo *output_shifts = nullptr); // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override; private: - bool _fuse_output_stage{ false }; - signed int _m{ 1 }; - signed int _n{ 1 }; - signed int _k{ 1 }; + bool _fuse_output_stage{false}; + signed int _m{1}; + signed int _n{1}; + signed int _k{1}; }; } // namespace kernels } // namespace opencl } // namespace arm_compute -#endif /* ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMULKERNEL_H */
\ No newline at end of file +#endif /* ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMULKERNEL_H */ |