aboutsummaryrefslogtreecommitdiff
path: root/examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp
diff options
context:
space:
mode:
authorSiCong Li <sicong.li@arm.com>2019-09-24 15:50:34 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-09-27 09:31:07 +0000
commit240b79de1c211ebb8d439b4a1c8c79777aa36f13 (patch)
treec468b1d8f66699e90c728b650c2c6959aef6b2d5 /examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp
parentbc88c621bd89cb93b503e2d8386383ff9ae6412a (diff)
downloadComputeLibrary-240b79de1c211ebb8d439b4a1c8c79777aa36f13.tar.gz
COMPMID-2692 Use existing command line infrastructure for
CLGEMMReshapedOnlyRHS * Refactor to use existing command line infrastructure. * Fix: Remove errorneous initialisation code for dst matrix. * Fix: Correctly set bias tensor to be a vector. Change-Id: I787bfa08392df806aba3b0be09bab015f16010f7 Signed-off-by: SiCong Li <sicong.li@arm.com> Reviewed-on: https://review.mlplatform.org/c/1985 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp')
-rw-r--r--examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp154
1 files changed, 97 insertions, 57 deletions
diff --git a/examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp b/examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp
index 5075c1daf9..0d161aab2d 100644
--- a/examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp
+++ b/examples/gemm_tuner/cl_gemm_reshaped_rhs_only.cpp
@@ -25,6 +25,7 @@
#error "This example needs to be built with -DARM_COMPUTE_CL"
#endif /* ARM_COMPUTE_CL */
+#include "CommonGemmExampleOptions.h"
#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/KernelDescriptors.h"
@@ -35,49 +36,27 @@
#include "arm_compute/runtime/CL/CLTuner.h"
#include "tests/CL/Helper.h"
#include "utils/Utils.h"
+#include "utils/command_line/CommandLineOptions.h"
+#include "utils/command_line/CommandLineParser.h"
#include <cstdlib>
using namespace arm_compute;
using namespace utils;
using namespace arm_compute::misc::shape_calculator;
+using namespace gemm_tuner;
namespace
{
-/** Structure holding all the common gemm example parameters */
-struct CommonGemmExampleParams
-{
- size_t M{ 100 };
- size_t N{ 100 };
- size_t K{ 50 };
- size_t B{ 1 };
-};
-
-/** Formatted output of the CommonGemmExampleParams type
- *
- * @param[out] os Output stream.
- * @param[in] common_params Common parameters to output
- *
- * @return Modified output stream.
- */
-::std::ostream &operator<<(::std::ostream &os, const CommonGemmExampleParams &common_params)
-{
- os << "M : " << common_params.M << std::endl;
- os << "N : " << common_params.N << std::endl;
- os << "K : " << common_params.K << std::endl;
- os << "B : " << common_params.B << std::endl;
- return os;
-}
-
/** Structure holding all tunable gemm configs specific to this example/strategy */
struct GemmConfigs
{
- size_t m0{ 4 };
- size_t n0{ 4 };
- size_t k0{ 4 };
- size_t h0{ 1 };
- bool interleave_rhs{ true };
- bool transpose_rhs{ true };
+ size_t m0{ 4 }; /**< Number of rows processed by the matrix multiplication */
+ size_t n0{ 4 }; /**< Number of columns processed by the matrix multiplication */
+ size_t k0{ 4 }; /**< Number of partial accumulations performed by the matrix multiplication */
+ size_t h0{ 1 }; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */
+ bool interleave_rhs{ true }; /**< Interleave rhs matrix */
+ bool transpose_rhs{ true }; /**< Transpose rhs matrix */
};
/** Formatted output of the GemmConfigs type
@@ -100,6 +79,67 @@ struct GemmConfigs
os << "transpose_rhs : " << (configs.transpose_rhs ? true_str : false_str) << std::endl;
return os;
}
+
+/** Command line options for gemm configs */
+class GemmConfigOptions
+{
+public:
+ /** Constructor
+ *
+ * @param[in,out] parser A parser on which "parse()" hasn't been called yet.
+ */
+ GemmConfigOptions(CommandLineParser &parser)
+ : m0(parser.add_positional_option<SimpleOption<size_t>>("m0", 4)),
+ n0(parser.add_positional_option<SimpleOption<size_t>>("n0", 4)),
+ k0(parser.add_positional_option<SimpleOption<size_t>>("k0", 4)),
+ h0(parser.add_positional_option<SimpleOption<size_t>>("h0", 1)),
+ interleave_rhs(parser.add_positional_option<SimpleOption<size_t>>("interleave_rhs", 1)),
+ transpose_rhs(parser.add_positional_option<SimpleOption<size_t>>("transpose_rhs", 1))
+ {
+ m0->set_help("Number of rows processed by the matrix multiplication");
+ n0->set_help("Number of columns processed by the matrix multiplication");
+ k0->set_help("Number of partial accumulations performed by the matrix multiplication");
+ h0->set_help("Number of horizontal blocks of size (k0xn0) stored on the same output row");
+ interleave_rhs->set_help("Interleave rhs matrix (1) / Do not interleave rhs matrix (0)");
+ transpose_rhs->set_help("Transpose rhs matrix (1) / Do not transpose rhs matrix (0)");
+ }
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ GemmConfigOptions(const GemmConfigOptions &) = delete;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ GemmConfigOptions &operator=(const GemmConfigOptions &) = delete;
+ /** Allow instances of this class to be moved */
+ GemmConfigOptions(GemmConfigOptions &&) = default;
+ /** Allow instances of this class to be moved */
+ GemmConfigOptions &operator=(GemmConfigOptions &&) = default;
+ /** Default destructor */
+ ~GemmConfigOptions() = default;
+
+ SimpleOption<size_t> *m0; /**< Number of rows processed by the matrix multiplication option */
+ SimpleOption<size_t> *n0; /**< Number of columns processed by the matrix multiplication option */
+ SimpleOption<size_t> *k0; /**< Number of partial accumulations performed by the matrix multiplication option */
+ SimpleOption<size_t> *h0; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row option */
+ SimpleOption<size_t> *interleave_rhs; /**< Interleave rhs matrix option (1 enable; 0 disable) */
+ SimpleOption<size_t> *transpose_rhs; /**< Transpose rhs matrix option (1 enable; 0 disable) */
+};
+
+/** Consumes the gemm configuration options and creates a structure containing all information
+ *
+ * @param[in] options Options to consume
+ *
+ * @return Structure containing the gemm configurations
+ */
+GemmConfigs consume_gemm_configs(const GemmConfigOptions &options)
+{
+ GemmConfigs configs;
+ configs.m0 = options.m0->value();
+ configs.n0 = options.n0->value();
+ configs.k0 = options.k0->value();
+ configs.h0 = options.h0->value();
+ configs.interleave_rhs = options.interleave_rhs->value() != 0;
+ configs.transpose_rhs = options.transpose_rhs->value() != 0;
+ return configs;
+}
+
} // namespace
// Create function for CLGEMMMatrixMultiplyReshapedOnlyRHSKernel
using CLGEMMMatrixMultiplyReshapedOnlyRHS = test::CLSynthetizeFunction<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel>;
@@ -116,33 +156,35 @@ public:
const ActivationLayerInfo act_info = ActivationLayerInfo();
CommonGemmExampleParams params;
GemmConfigs configs;
- if(argc < 9 || argc > 11)
+
+ // Set up command line parser and options
+ CommandLineParser parser;
+ CommonGemmExampleOptions param_options(parser);
+ GemmConfigOptions config_options(parser);
+
+ // Parse command line options
+ parser.parse(argc, argv);
+ if(param_options.help->is_set() && param_options.help->value())
{
- // Print help
- // Use default parameters
- std::cerr << "Usage: ./build/cl_gemm_reshaped_rhs_only M N K B m0 n0 k0 h0 [interleave_rhs = 1] [transpose_rhs = 1]\n\n";
+ // Print help message
+ parser.print_help(argv[0]);
+ return false;
+ }
+ if(!parser.validate())
+ {
+ // Invalid arguments. Use default parameters and configs
+ std::cerr << "Invalid arguments." << std::endl;
+ parser.print_help(argv[0]);
std::cerr << "Falling back to default parameters and configs" << std::endl;
}
else
{
- // Set parameters from command line arguments
- params.M = strtol(argv[1], nullptr, 10);
- params.N = strtol(argv[2], nullptr, 10);
- params.K = strtol(argv[3], nullptr, 10);
- params.B = strtol(argv[4], nullptr, 10);
- configs.m0 = strtol(argv[5], nullptr, 10);
- configs.n0 = strtol(argv[6], nullptr, 10);
- configs.k0 = strtol(argv[7], nullptr, 10);
- configs.h0 = strtol(argv[8], nullptr, 10);
- if(argc > 9)
- {
- configs.interleave_rhs = strtol(argv[9], nullptr, 10) == 1;
- }
- if(argc > 10)
- {
- configs.transpose_rhs = strtol(argv[10], nullptr, 10) == 1;
- }
+ // Get parameters and configs from command-line options
+ params = consume_common_gemm_example_parameters(param_options);
+ configs = consume_gemm_configs(config_options);
}
+
+ // Print gemm parameters and configurations
std::cerr << "Gemm parameters:" << std::endl;
std::cerr << params << std::endl;
std::cerr << "Gemm configurations:" << std::endl;
@@ -152,9 +194,7 @@ public:
lhs.allocator()->init(TensorInfo(TensorShape(params.K, params.M, params.B), 1, data_type));
rhs.allocator()->init(TensorInfo(TensorShape(params.N, params.K, params.B), 1, data_type));
- bias.allocator()->init(TensorInfo(TensorShape(params.N, params.M, params.B), 1, data_type));
-
- init_sgemm_output(dst, lhs, rhs, data_type);
+ bias.allocator()->init(TensorInfo(TensorShape(params.N, 1, params.B), 1, data_type));
GEMMLHSMatrixInfo lhs_info;
lhs_info.m0 = configs.m0;
@@ -217,9 +257,9 @@ private:
/** Main program for gemm reshaped rhs only test
*
* @param[in] argc Number of arguments
- * @param[in] argv Arguments ( M, N, K, B, m0, n0, k0, h0, [optional] interleave_rhs, [optional] transpose_rhs )
+ * @param[in] argv Arguments ( [optional] M, [optional] N, [optional] K, [optional] B, [optional] m0, [optional] n0, [optional] k0, [optional] h0, [optional] interleave_rhs, [optional] transpose_rhs )
*/
int main(int argc, char **argv)
{
- return utils::run_example<CLGEMMMatrixMultiplyReshapedOnlyRHSExample>(argc, argv);
+ return run_example<CLGEMMMatrixMultiplyReshapedOnlyRHSExample>(argc, argv);
}