aboutsummaryrefslogtreecommitdiff
path: root/examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp')
-rw-r--r--examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp18
1 files changed, 12 insertions, 6 deletions
diff --git a/examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp b/examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp
index 68bec9da6e..dbaaca6048 100644
--- a/examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp
+++ b/examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp
@@ -33,7 +33,7 @@
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/CL/CLTuner.h"
-#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
+#include "src/core/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.h"
#include "tests/CL/Helper.h"
#include "utils/Utils.h"
#include "utils/command_line/CommandLineOptions.h"
@@ -42,6 +42,7 @@
#include <cstdlib>
using namespace arm_compute;
+using namespace arm_compute::opencl::kernels;
using namespace utils;
using namespace arm_compute::misc::shape_calculator;
using namespace gemm_tuner;
@@ -147,8 +148,8 @@ GemmConfigs consume_gemm_configs(const GemmConfigOptions &options)
}
} // namespace
-// Create function for CLGEMMMatrixMultiplyReshapedOnlyRHSKernel
-using CLGEMMMatrixMultiplyReshapedOnlyRHS = test::CLSynthetizeFunction<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel>;
+// Create function for ClGemmMatrixMultiplyReshapedOnlyRhsKernel
+using CLGEMMMatrixMultiplyReshapedOnlyRHS = test::CLSynthetizeOperator<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>;
class CLGEMMMatrixMultiplyReshapedOnlyRHSExample : public Example
{
@@ -238,7 +239,7 @@ public:
// Validate argments
Status status{};
- status = gemm.validate((&lhs)->info(), (&rhs_reshaped)->info(), (&bias)->info(), (&dst)->info(), alpha, beta, lhs_info, rhs_info, kernel_info);
+ status = gemm.validate(lhs.info(), rhs_reshaped.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info);
if(!status)
{
// Unsupported arguments
@@ -248,7 +249,7 @@ public:
}
// Configure function
- gemm.configure(&lhs, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
+ gemm.configure(lhs.info(), rhs_reshaped.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info);
// Allocate tensors
lhs.allocator()->allocate();
@@ -262,7 +263,12 @@ public:
void do_run() override
{
// Execute the function
- gemm.run();
+ ITensorPack gemm_pack({ { ACL_SRC_0, &lhs },
+ { ACL_SRC_1, &rhs_reshaped },
+ { ACL_SRC_2, &bias },
+ { ACL_DST, &dst }
+ });
+ gemm.run(gemm_pack);
// Make sure all the OpenCL jobs are done executing:
CLScheduler::get().sync();