Diffstat (limited to 'examples/gemm_tuner/cl_gemm_native.cpp')
-rw-r--r-- examples/gemm_tuner/cl_gemm_native.cpp | 65
1 file changed, 40 insertions(+), 25 deletions(-)
diff --git a/examples/gemm_tuner/cl_gemm_native.cpp b/examples/gemm_tuner/cl_gemm_native.cpp
index 0cacd82087..7daa0b07d3 100644
--- a/examples/gemm_tuner/cl_gemm_native.cpp
+++ b/examples/gemm_tuner/cl_gemm_native.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,23 +25,24 @@
#error "This example needs to be built with -DARM_COMPUTE_CL"
#endif /* ARM_COMPUTE_CL */
-#include "CommonGemmExampleOptions.h"
-#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/runtime/CL/CLFunctions.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/CL/CLTuner.h"
+
+#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.h"
#include "tests/CL/Helper.h"
-#include "utils/Utils.h"
#include "utils/command_line/CommandLineOptions.h"
#include "utils/command_line/CommandLineParser.h"
+#include "utils/Utils.h"
+#include "CommonGemmExampleOptions.h"
#include <cstdlib>
using namespace arm_compute;
+using namespace arm_compute::opencl::kernels;
using namespace utils;
using namespace arm_compute::misc::shape_calculator;
using namespace gemm_tuner;
@@ -51,9 +52,9 @@ namespace
/** Structure holding all tunable gemm configs specific to this example/strategy */
struct GemmConfigs
{
- size_t m0{ 4 }; /**< Number of rows processed by the matrix multiplication */
- size_t n0{ 4 }; /**< Number of columns processed by the matrix multiplication */
- size_t k0{ 4 }; /**< Number of partial accumulations performed by the matrix multiplication */
+ size_t m0{4}; /**< Number of rows processed by the matrix multiplication */
+ size_t n0{4}; /**< Number of columns processed by the matrix multiplication */
+ size_t k0{4}; /**< Number of partial accumulations performed by the matrix multiplication */
};
/** Formatted output of the GemmConfigs type
@@ -123,8 +124,8 @@ GemmConfigs consume_gemm_configs(const GemmConfigOptions &options)
}
} // namespace
-// Create function for CLGEMMMatrixMultiplyNativeKernel
-using CLGEMMMatrixMultiplyNative = test::CLSynthetizeFunction<CLGEMMMatrixMultiplyNativeKernel>;
+// Create function for ClGemmMatrixMultiplyNativeKernel
+using CLGEMMMatrixMultiplyNative = test::CLSynthetizeOperator<ClGemmMatrixMultiplyNativeKernel>;
class CLGEMMMatrixMultiplyNativeExample : public Example
{
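Note: CLSynthetizeOperator wraps the new stateless, operator-style kernel, so the synthesized function no longer captures tensors at configure time. A minimal sketch of the call sequence this patch moves to, using only the names that appear in this diff (surrounding setup omitted):

    // Configure against tensor metadata (ITensorInfo) only; no tensors are stored.
    gemm.configure(lhs.info(), rhs.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info);
    // Bind the concrete tensors at execution time through an ITensorPack.
    ITensorPack pack({{ACL_SRC_0, &lhs}, {ACL_SRC_1, &rhs}, {ACL_SRC_2, &bias}, {ACL_DST, &dst}});
    gemm.run(pack);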
@@ -132,10 +133,9 @@ public:
bool do_setup(int argc, char **argv) override
{
// Default parameters
- const DataType data_type = DataType::F32;
- const float alpha = 1.0f;
- const float beta = 0.0f;
- const ActivationLayerInfo act_info = ActivationLayerInfo();
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+ const ActivationLayerInfo act_info = ActivationLayerInfo();
CommonGemmExampleParams params;
GemmConfigs configs;
@@ -146,13 +146,13 @@ public:
// Parse command line options
parser.parse(argc, argv);
- if(param_options.help->is_set() && param_options.help->value())
+ if (param_options.help->is_set() && param_options.help->value())
{
// Print help message
parser.print_help(argv[0]);
return false;
}
- if(!parser.validate())
+ if (!parser.validate())
{
// Invalid arguments. Use default parameters and configs
std::cerr << "Invalid arguments." << std::endl;
@@ -167,16 +167,18 @@ public:
}
// Print gemm parameters and configurations
- std::cerr << "Gemm parameters:" << std::endl;
- std::cerr << params << std::endl;
- std::cerr << "Gemm configurations:" << std::endl;
- std::cerr << configs << std::endl;
+ std::cout << "Gemm parameters:" << std::endl;
+ std::cout << params << std::endl;
+ std::cout << "Gemm configurations:" << std::endl;
+ std::cout << configs << std::endl;
+
+ tuner.set_tuner_mode(params.tuner_mode);
CLScheduler::get().default_init(&tuner);
- lhs.allocator()->init(TensorInfo(TensorShape(params.K, params.M, params.B), 1, data_type));
- rhs.allocator()->init(TensorInfo(TensorShape(params.N, params.K, params.B), 1, data_type));
- bias.allocator()->init(TensorInfo(TensorShape(params.N, 1, params.B), 1, data_type));
+ lhs.allocator()->init(TensorInfo(TensorShape(params.K, params.M, params.B), 1, params.data_type));
+ rhs.allocator()->init(TensorInfo(TensorShape(params.N, params.K, params.B), 1, params.data_type));
+ bias.allocator()->init(TensorInfo(TensorShape(params.N, 1, params.B), 1, params.data_type));
GEMMLHSMatrixInfo lhs_info;
lhs_info.m0 = configs.m0;
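Note: the tuner mode set above controls how exhaustively CLTuner searches for local work-group sizes, and must be chosen before the scheduler is initialised with the tuner. A minimal sketch, assuming the CLTunerMode enumeration (EXHAUSTIVE / NORMAL / RAPID) from the library's CLTuner API:

    CLTuner tuner;
    // Trade tuning time for search quality: RAPID < NORMAL < EXHAUSTIVE.
    tuner.set_tuner_mode(CLTunerMode::NORMAL);
    CLScheduler::get().default_init(&tuner);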
@@ -195,8 +197,20 @@ public:
kernel_info.broadcast_bias = true;
kernel_info.activation_info = act_info;
+ // Validate arguments
+ Status status{};
+ status = gemm.validate(lhs.info(), rhs.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info,
+ kernel_info);
+ if (!status)
+ {
+ // Unsupported arguments
+ std::cerr << "Unsupported arguments." << std::endl;
+ std::cerr << "Check documentation for supported/unsupported combinations" << std::endl;
+ return false;
+ }
+
// Configure function
- gemm.configure(&lhs, &rhs, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
+ gemm.configure(lhs.info(), rhs.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info);
// Allocate tensors
lhs.allocator()->allocate();
@@ -209,7 +223,8 @@ public:
void do_run() override
{
// Execute the function
- gemm.run();
+ ITensorPack gemm_pack({{ACL_SRC_0, &lhs}, {ACL_SRC_1, &rhs}, {ACL_SRC_2, &bias}, {ACL_DST, &dst}});
+ gemm.run(gemm_pack);
// Make sure all the OpenCL jobs are done executing:
CLScheduler::get().sync();
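Note: the validate/configure/run split added by this patch follows the library's usual operator flow: statically check the argument combination, configure against ITensorInfo, then execute with the packed tensors. A minimal sketch of a more descriptive failure report than the one above, assuming Status::error_description() from the library's Error.h:

    Status status = gemm.validate(lhs.info(), rhs.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info);
    if (!status)
    {
        // Surface the reason the argument combination is unsupported.
        std::cerr << status.error_description() << std::endl;
        return false;
    }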