aboutsummaryrefslogtreecommitdiff
path: root/examples/gemm_tuner/cl_gemm_reshaped.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/gemm_tuner/cl_gemm_reshaped.cpp')
-rw-r--r--examples/gemm_tuner/cl_gemm_reshaped.cpp8
1 files changed, 8 insertions, 0 deletions
diff --git a/examples/gemm_tuner/cl_gemm_reshaped.cpp b/examples/gemm_tuner/cl_gemm_reshaped.cpp
index 6445592a72..e579ed762c 100644
--- a/examples/gemm_tuner/cl_gemm_reshaped.cpp
+++ b/examples/gemm_tuner/cl_gemm_reshaped.cpp
@@ -27,6 +27,7 @@
#include "CommonGemmExampleOptions.h"
#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
+#include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/Types.h"
@@ -165,6 +166,8 @@ GemmConfigs consume_gemm_configs(const GemmConfigOptions &options)
}
} // namespace
+// Create function for CLGEMMReshapeLHSMatrixKernel
+using CLGEMMReshapeLHSMatrix = test::CLSynthetizeFunction<CLGEMMReshapeLHSMatrixKernel>;
// Create function for CLGEMMMatrixMultiplyReshapedKernel
using CLGEMMMatrixMultiplyReshaped = test::CLSynthetizeFunction<CLGEMMMatrixMultiplyReshapedKernel>;
@@ -249,6 +252,9 @@ public:
// Initialise rhs_reshaped tensor info
auto_init_if_empty(*rhs_reshaped.info(), rhs.info()->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*rhs.info(), rhs_info)));
+ // Configure reshape lhs function
+ reshape_lhs.configure(&lhs, &lhs_reshaped, lhs_info);
+
// Configure function
gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
@@ -265,6 +271,7 @@ public:
void do_run() override
{
// Execute the function
+ reshape_lhs.run();
gemm.run();
// Make sure all the OpenCL jobs are done executing:
@@ -283,6 +290,7 @@ private:
CLTensor bias{};
CLTensor dst{};
CLTuner tuner{};
+ CLGEMMReshapeLHSMatrix reshape_lhs{};
CLGEMMMatrixMultiplyReshaped gemm{};
};