diff options
Diffstat (limited to 'examples/gemm_tuner/cl_gemm_reshaped.cpp')
-rw-r--r-- | examples/gemm_tuner/cl_gemm_reshaped.cpp | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/examples/gemm_tuner/cl_gemm_reshaped.cpp b/examples/gemm_tuner/cl_gemm_reshaped.cpp index 6445592a72..e579ed762c 100644 --- a/examples/gemm_tuner/cl_gemm_reshaped.cpp +++ b/examples/gemm_tuner/cl_gemm_reshaped.cpp @@ -27,6 +27,7 @@ #include "CommonGemmExampleOptions.h" #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h" +#include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/KernelDescriptors.h" #include "arm_compute/core/Types.h" @@ -165,6 +166,8 @@ GemmConfigs consume_gemm_configs(const GemmConfigOptions &options) } } // namespace +// Create function for CLGEMMReshapeLHSMatrixKernel +using CLGEMMReshapeLHSMatrix = test::CLSynthetizeFunction<CLGEMMReshapeLHSMatrixKernel>; // Create function for CLGEMMMatrixMultiplyReshapedKernel using CLGEMMMatrixMultiplyReshaped = test::CLSynthetizeFunction<CLGEMMMatrixMultiplyReshapedKernel>; @@ -249,6 +252,9 @@ public: // Initialise rhs_reshaped tensor info auto_init_if_empty(*rhs_reshaped.info(), rhs.info()->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*rhs.info(), rhs_info))); + // Configure reshape lhs function + reshape_lhs.configure(&lhs, &lhs_reshaped, lhs_info); + // Configure function gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info); @@ -265,6 +271,7 @@ public: void do_run() override { // Execute the function + reshape_lhs.run(); gemm.run(); // Make sure all the OpenCL jobs are done executing: @@ -283,6 +290,7 @@ private: CLTensor bias{}; CLTensor dst{}; CLTuner tuner{}; + CLGEMMReshapeLHSMatrix reshape_lhs{}; CLGEMMMatrixMultiplyReshaped gemm{}; }; |