aboutsummaryrefslogtreecommitdiff
path: root/examples/gemm_tuner/cl_gemm_reshaped.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/gemm_tuner/cl_gemm_reshaped.cpp')
-rw-r--r--examples/gemm_tuner/cl_gemm_reshaped.cpp35
1 files changed, 22 insertions, 13 deletions
diff --git a/examples/gemm_tuner/cl_gemm_reshaped.cpp b/examples/gemm_tuner/cl_gemm_reshaped.cpp
index 444a342d74..e6caeec873 100644
--- a/examples/gemm_tuner/cl_gemm_reshaped.cpp
+++ b/examples/gemm_tuner/cl_gemm_reshaped.cpp
@@ -33,8 +33,8 @@
#include "arm_compute/runtime/CL/CLTuner.h"
#include "examples/gemm_tuner/CommonGemmExampleOptions.h"
#include "examples/gemm_tuner/GemmTunerHelpers.h"
-#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
-#include "src/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
+#include "src/core/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h"
+#include "src/core/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.h"
#include "tests/CL/Helper.h"
#include "utils/Utils.h"
#include "utils/command_line/CommandLineOptions.h"
@@ -43,6 +43,7 @@
#include <cstdlib>
using namespace arm_compute;
+using namespace arm_compute::opencl::kernels;
using namespace utils;
using namespace arm_compute::misc::shape_calculator;
using namespace gemm_tuner;
@@ -172,10 +173,11 @@ GemmConfigs consume_gemm_configs(const GemmConfigOptions &options)
}
} // namespace
-// Create function for CLGEMMReshapeLHSMatrixKernel
-using CLGEMMReshapeLHSMatrix = test::CLSynthetizeFunction<CLGEMMReshapeLHSMatrixKernel>;
-// Create function for CLGEMMMatrixMultiplyReshapedKernel
-using CLGEMMMatrixMultiplyReshaped = test::CLSynthetizeFunction<CLGEMMMatrixMultiplyReshapedKernel>;
+
+// Create function for ClGemmReshapeLhsMatrixKernel
+using CLGEMMReshapeLHSMatrix = test::CLSynthetizeOperator<ClGemmReshapeLhsMatrixKernel>;
+// Create function for ClGemmMatrixMultiplyReshapedKernel
+using CLGEMMMatrixMultiplyReshaped = test::CLSynthetizeOperator<ClGemmMatrixMultiplyReshapedKernel>;
class CLGEMMMatrixMultiplyReshapedExample : public Example
{
@@ -271,7 +273,7 @@ public:
// Validate argments
Status status{};
- status = reshape_lhs.validate((&lhs)->info(), (&lhs_reshaped)->info(), lhs_info, kernel_info.reinterpret_input_as_3d);
+ status = reshape_lhs.validate(lhs.info(), lhs_reshaped.info(), lhs_info, kernel_info.reinterpret_input_as_3d);
if(!status)
{
// Unsupported arguments
@@ -280,7 +282,7 @@ public:
return false;
}
- status = gemm.validate((&lhs_reshaped)->info(), (&rhs_reshaped)->info(), (&bias)->info(), (&dst)->info(), alpha, beta, lhs_info, rhs_info, kernel_info);
+ status = gemm.validate(lhs_reshaped.info(), rhs_reshaped.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info);
if(!status)
{
// Unsupported arguments
@@ -290,10 +292,10 @@ public:
}
// Configure reshape lhs function
- reshape_lhs.configure(&lhs, &lhs_reshaped, lhs_info);
+ reshape_lhs.configure(lhs.info(), lhs_reshaped.info(), lhs_info);
// Configure function
- gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
+ gemm.configure(lhs_reshaped.info(), rhs_reshaped.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info);
// Allocate tensors
lhs.allocator()->allocate();
@@ -307,9 +309,16 @@ public:
}
void do_run() override
{
- // Execute the function
- reshape_lhs.run();
- gemm.run();
+ // Execute the functions
+ ITensorPack reshape_lsh_pack({ { ACL_SRC, &lhs }, { ACL_DST, &lhs_reshaped } });
+ reshape_lhs.run(reshape_lsh_pack);
+
+ ITensorPack gemm_pack({ { ACL_SRC_0, &lhs_reshaped },
+ { ACL_SRC_1, &rhs_reshaped },
+ { ACL_SRC_2, &bias },
+ { ACL_DST, &dst }
+ });
+ reshape_lhs.run(gemm_pack);
// Make sure all the OpenCL jobs are done executing:
CLScheduler::get().sync();