From 108a95e046dde880075b6c278b44033d13f55be3 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 27 Mar 2019 13:55:59 +0000 Subject: COMPMID-2067: Add batch size option in validate_cl_gemm example -Add batch option -Switches to the command line parsing interface Change-Id: I7c6df752797a37440f980bce6b29e9b5c5244fcb Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/917 Tested-by: Arm Jenkins Reviewed-by: Michalis Spyrou Comments-Addressed: Arm Jenkins --- tests/validate_examples/cl_gemm.cpp | 330 +++++++++++++++++++++--------------- utils/Utils.h | 2 +- 2 files changed, 190 insertions(+), 142 deletions(-) diff --git a/tests/validate_examples/cl_gemm.cpp b/tests/validate_examples/cl_gemm.cpp index 8b3a103db7..4e406cbd9b 100644 --- a/tests/validate_examples/cl_gemm.cpp +++ b/tests/validate_examples/cl_gemm.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -39,9 +39,12 @@ #include "tests/validation/reference/GEMM.h" #include "tests/validation/reference/GEMMLowp.h" -#include "ValidateExample.h" - +#include "utils/TypePrinter.h" #include "utils/Utils.h" +#include "utils/command_line/CommandLineOptions.h" +#include "utils/command_line/CommandLineParser.h" + +#include "ValidateExample.h" #include @@ -56,131 +59,153 @@ RelativeTolerance tolerance_f32(0.001f); /**< F32 Toleran RelativeTolerance tolerance_f16(half(0.2)); /**< F16 Tolerance value for comparing reference's output against implementation's output for floating point data types */ constexpr float tolerance_num_f16 = 0.02f; /**< F16 Tolerance number */ +namespace arm_compute +{ +DataType data_type_from_name(const std::string &name) +{ + static const std::map data_types = + { + { "f16", DataType::F16 }, + { "f32", DataType::F32 }, + { "qasymm8", DataType::QASYMM8 }, + }; + +#ifndef ARM_COMPUTE_EXCEPTIONS_DISABLED + try + { +#endif /* ARM_COMPUTE_EXCEPTIONS_DISABLED */ + return data_types.at(utility::tolower(name)); + +#ifndef ARM_COMPUTE_EXCEPTIONS_DISABLED + } + catch(const std::out_of_range &) + { + throw std::invalid_argument(name); + } +#endif /* ARM_COMPUTE_EXCEPTIONS_DISABLED */ +} + +inline ::std::istream &operator>>(::std::istream &stream, DataType &data_type) +{ + std::string value; + stream >> value; + data_type = data_type_from_name(value); + return stream; +} +} // namespace arm_compute +namespace +{ +class GEMMCommandLineOptions final +{ +public: + explicit GEMMCommandLineOptions(CommandLineParser &parser) noexcept + : help(parser.add_option("help")), + add_bias(parser.add_option("add_bias")), + M(parser.add_option>("m", 7)), + N(parser.add_option>("n", 3)), + K(parser.add_option>("k", 5)), + B(parser.add_option>("b", 1)), + alpha(parser.add_option>("alpha", 1.f)), + beta(parser.add_option>("beta", 0.f)), + offset_src0(parser.add_option>("offset_i0", 10)), + offset_src1(parser.add_option>("offset_i1", 10)), + offset_dst(parser.add_option>("offset_o", 10)), + scale_src0(parser.add_option>("scale_i0", 1.f / 255)), + scale_src1(parser.add_option>("scale_i1", 1.f / 255)), + scale_dst(parser.add_option>("scale_o", 1.f / 255)), + data_type() + { + // Setup data type + const std::set supported_data_types + { + DataType::F16, + DataType::F32, + DataType::QASYMM8, + }; + data_type = parser.add_option>("type", supported_data_types, DataType::F32); + + // Setup help strings + help->set_help("Show this help message"); + add_bias->set_help("Add bias to the GEMM. Used when running in QASYMM8"); + M->set_help("M value"); + N->set_help("N value"); + K->set_help("K value"); + B->set_help("B value - number of batches"); + alpha->set_help("Alpha value"); + beta->set_help("Beta value"); + offset_src0->set_help("Offset of first input. Used when running in QASYMM8"); + offset_src1->set_help("Offset of second input. Used when running in QASYMM8"); + offset_dst->set_help("Offset of output. Used when running in QASYMM8"); + scale_src0->set_help("Scale of first input. Used when running in QASYMM8"); + scale_src1->set_help("Scale of second input. Used when running in QASYMM8"); + scale_dst->set_help("Scale of output. Used when running in QASYMM8"); + data_type->set_help("Data type to use"); + } + /** Prevent instances of this class from being copied (As this class contains pointers) */ + GEMMCommandLineOptions(const GEMMCommandLineOptions &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + GEMMCommandLineOptions &operator=(const GEMMCommandLineOptions &) = delete; + /** Allow instances of this class to be moved */ + GEMMCommandLineOptions(GEMMCommandLineOptions &&) noexcept(true) = default; + /** Allow instances of this class to be moved */ + GEMMCommandLineOptions &operator=(GEMMCommandLineOptions &&) noexcept(true) = default; + /** Default destructor */ + ~GEMMCommandLineOptions() = default; + +public: + ToggleOption *help; + ToggleOption *add_bias; + SimpleOption *M; + SimpleOption *N; + SimpleOption *K; + SimpleOption *B; + SimpleOption *alpha; + SimpleOption *beta; + SimpleOption *offset_src0; + SimpleOption *offset_src1; + SimpleOption *offset_dst; + SimpleOption *scale_src0; + SimpleOption *scale_src1; + SimpleOption *scale_dst; + EnumOption *data_type; +}; +} // namespace + class CLGEMMValidateExample : public ValidateExample { public: bool do_setup(int argc, char **argv) override { - //TODO(antbar01): Update to use command line interface ? CLScheduler::get().default_init(); - if(argc == 2) - { - size_t dt = strtol(argv[1], nullptr, 10); - switch(dt) - { - case 1: - { - data_type = DataType::F16; - std::cout << "Usage: " << argv[0] << "1 M N K [alpha = 1.0f] [beta = 0.0f]\n"; - std::cout << "Using default values: Datatype=FP16 M=7, N=3, K=5, alpha=1.0f and beta=0.0f\n"; - break; - } - case 2: - { - data_type = DataType::QASYMM8; - std::cout << "Usage: " << argv[0] << "2 M N K [scale_src0 = 0.1f] [offset_scr0 = f] [scale_scr1 = 0.1f] [offset_scr1 = 10] [scale_dst = 0.1f] [offset_dst = 10] [bias = 1]\n"; - std::cout << - "Using default values: Datatype=QASYMM8 M=7, N=3, K=5, scale_src0 =(1.0f/255), offset_src0 = 10, scale_src1 =(1.0f/255), offset_src1 = 10, scale_dst =(1.0f/255), offset_dst = 10, bias=1\n\n"; - break; - } - case 0: - default: - { - data_type = DataType::F32; - std::cout << "Usage: " << argv[0] << "0 M N K [alpha = 1.0f] [beta = 0.0f]\n"; - std::cout << "Using default values: Datatype=FP32 M=7, N=3, K=5, alpha=1.0f and beta=0.0f\n"; - } - } - } - else if(argc < 5) - { - // Print help - std::cout << "Usage with datatype = FP32 : " << argv[0] << "0 M N K [alpha = 1.0f] [beta = 0.0f]\n"; - std::cout << " datatype = FP16 : " << argv[0] << "1 M N K [alpha = 1.0f] [beta = 0.0f]\n"; - std::cout << " datatype = QASYMM8 : " << argv[0] << "2 M N K [scale_src0 = 0.1f] [offset_scr0 = f] [scale_scr1 = 0.1f] [offset_scr1 = 10] [scale_dst = 0.1f] [offset_dst = 10] [bias = 1]\n"; - std::cout << "Too few or no arguments provided.\n"; - std::cout << "Using default values: Datatype=FP32 M=7, N=3, K=5, alpha=1.0f and beta=0.0f\n"; - } - else + + // Parse options + CommandLineParser parser; + GEMMCommandLineOptions gemm_options(parser); + parser.parse(argc, argv); + + // Print help + const bool print_help = gemm_options.help->is_set() ? gemm_options.help->value() : false; + if(print_help) { - size_t dt = strtol(argv[1], nullptr, 10); - switch(dt) - { - case 1: - { - data_type = DataType::F16; - break; - } - case 2: - { - data_type = DataType::QASYMM8; - break; - } - case 0: - default: - data_type = DataType::F32; - } - M = strtol(argv[2], nullptr, 10); - N = strtol(argv[3], nullptr, 10); - K = strtol(argv[4], nullptr, 10); + parser.print_help(argv[0]); + return false; } - switch(data_type) + // Consume parameters + consume_params(gemm_options); + print_parameters_internal(); + + // Calculate re-quantization parameters + if(data_type == DataType::QASYMM8) { - case DataType::F16: - case DataType::F32: - { - if(argc > 5) - { - alpha = strtof(argv[5], nullptr); - if(argc > 6) - { - beta = strtof(argv[6], nullptr); - } - } - break; - } - case DataType::QASYMM8: - { - if(argc > 5) - { - scale_src0 = strtof(argv[5], nullptr); - if(argc > 6) - { - offset_src0 = strtol(argv[6], nullptr, 10); - if(argc > 7) - { - scale_src1 = strtof(argv[7], nullptr); - if(argc > 8) - { - offset_src1 = strtol(argv[8], nullptr, 10); - if(argc > 9) - { - scale_dst = strtof(argv[9], nullptr); - if(argc > 10) - { - offset_dst = strtol(argv[10], nullptr, 10); - if(argc > 11) - { - add_bias = (strtol(argv[11], nullptr, 10) == 1); - } - } - } - } - } - } - } - float multiplier = scale_src0 * scale_src1 / scale_dst; - quantization::calculate_quantized_multiplier_less_than_one(multiplier, &dst_multiplier, &dst_shift); - break; - } - default: - break; + float multiplier = scale_src0 * scale_src1 / scale_dst; + quantization::calculate_quantized_multiplier_less_than_one(multiplier, &dst_multiplier, &dst_shift); } - src0.allocator()->init(TensorInfo(TensorShape(K, M), 1, data_type)); - src1.allocator()->init(TensorInfo(TensorShape(N, K), 1, data_type)); - src2.allocator()->init(TensorInfo(TensorShape(N, M), 1, data_type)); + // Initialize GEMM inputs/outputs + src0.allocator()->init(TensorInfo(TensorShape(K, M, B), 1, data_type)); + src1.allocator()->init(TensorInfo(TensorShape(N, K, B), 1, data_type)); + src2.allocator()->init(TensorInfo(TensorShape(N, M, B), 1, data_type)); init_sgemm_output(dst, src0, src1, data_type); // Configure function @@ -220,26 +245,27 @@ public: return true; } - void print_parameters(framework::Printer &printer) override + void print_parameters_internal() { - printer.print_entry("Datatype", string_from_data_type(data_type)); - printer.print_entry("M", support::cpp11::to_string(M)); - printer.print_entry("N", support::cpp11::to_string(N)); - printer.print_entry("K", support::cpp11::to_string(K)); + std::cout << "Datatype : " << string_from_data_type(data_type) << "\n"; + std::cout << "M : " << support::cpp11::to_string(M) << "\n"; + std::cout << "N : " << support::cpp11::to_string(N) << "\n"; + std::cout << "K : " << support::cpp11::to_string(K) << "\n"; + std::cout << "B : " << support::cpp11::to_string(B) << "\n"; if(data_type == DataType::QASYMM8) { - printer.print_entry("Scale_Src0", support::cpp11::to_string(scale_src0)); - printer.print_entry("Offset_Src0", support::cpp11::to_string(offset_src0)); - printer.print_entry("Scale_Scr1", support::cpp11::to_string(scale_src1)); - printer.print_entry("Offset_Src1", support::cpp11::to_string(offset_src1)); - printer.print_entry("Scale_Dst", support::cpp11::to_string(scale_dst)); - printer.print_entry("Offset_Dst", support::cpp11::to_string(offset_dst)); - printer.print_entry("Bias", support::cpp11::to_string(add_bias)); + std::cout << "Scale_Src0 : " << support::cpp11::to_string(scale_src0) << "\n"; + std::cout << "Offset_Src0 : " << support::cpp11::to_string(offset_src0) << "\n"; + std::cout << "Scale_Scr1 : " << support::cpp11::to_string(scale_src1) << "\n"; + std::cout << "Offset_Src1 : " << support::cpp11::to_string(offset_src1) << "\n"; + std::cout << "Scale_Dst : " << support::cpp11::to_string(scale_dst) << "\n"; + std::cout << "Offset_Dst : " << support::cpp11::to_string(offset_dst) << "\n"; + std::cout << "Bias : " << support::cpp11::to_string(add_bias) << "\n"; } else { - printer.print_entry("Alpha", support::cpp11::to_string(alpha)); - printer.print_entry("Beta", support::cpp11::to_string(beta)); + std::cout << "Alpha : " << support::cpp11::to_string(alpha) << "\n"; + std::cout << "Beta : " << support::cpp11::to_string(beta) << "\n"; } } @@ -249,9 +275,9 @@ public: { case DataType::F16: { - SimpleTensor ref_src0 = { TensorShape(K, M), data_type, 1 }; - SimpleTensor ref_src1 = { TensorShape(N, K), data_type, 1 }; - SimpleTensor ref_src2 = { TensorShape(N, M), data_type, 1 }; + SimpleTensor ref_src0 = { TensorShape(K, M, B), data_type, 1 }; + SimpleTensor ref_src1 = { TensorShape(N, K, B), data_type, 1 }; + SimpleTensor ref_src2 = { TensorShape(N, M, B), data_type, 1 }; fill(ref_src0, 0); fill(ref_src1, 1); @@ -263,9 +289,9 @@ public: } case DataType::F32: { - SimpleTensor ref_src0 = { TensorShape(K, M), data_type, 1 }; - SimpleTensor ref_src1 = { TensorShape(N, K), data_type, 1 }; - SimpleTensor ref_src2 = { TensorShape(N, M), data_type, 1 }; + SimpleTensor ref_src0 = { TensorShape(K, M, B), data_type, 1 }; + SimpleTensor ref_src1 = { TensorShape(N, K, B), data_type, 1 }; + SimpleTensor ref_src2 = { TensorShape(N, M, B), data_type, 1 }; fill(ref_src0, 0); fill(ref_src1, 1); @@ -277,15 +303,15 @@ public: } case DataType::QASYMM8: { - SimpleTensor ref_src0{ TensorShape(K, M), data_type, 1 }; - SimpleTensor ref_src1{ TensorShape(N, K), data_type, 1 }; + SimpleTensor ref_src0{ TensorShape(K, M, B), data_type, 1 }; + SimpleTensor ref_src1{ TensorShape(N, K, B), data_type, 1 }; SimpleTensor ref_dst; // Fill reference fill(ref_src0, 0); fill(ref_src1, 1); - SimpleTensor ref_tmp_dst = reference::gemmlowp_matrix_multiply_core(ref_src0, ref_src1, TensorShape(N, M), offset_src0, offset_src1); + SimpleTensor ref_tmp_dst = reference::gemmlowp_matrix_multiply_core(ref_src0, ref_src1, TensorShape(N, M, B), offset_src0, offset_src1); if(add_bias) { @@ -350,6 +376,28 @@ private: } } + void consume_params(const GEMMCommandLineOptions &opts) + { + ARM_COMPUTE_ERROR_ON(opts.M->value() <= 0); + ARM_COMPUTE_ERROR_ON(opts.N->value() <= 0); + ARM_COMPUTE_ERROR_ON(opts.K->value() <= 0); + ARM_COMPUTE_ERROR_ON(opts.B->value() <= 0); + M = opts.M->value(); + N = opts.N->value(); + K = opts.K->value(); + B = opts.B->value(); + alpha = opts.alpha->value(); + beta = opts.beta->value(); + offset_src0 = opts.offset_src0->value(); + offset_src1 = opts.offset_src1->value(); + offset_dst = opts.offset_dst->value(); + scale_src0 = opts.scale_src0->value(); + scale_src1 = opts.scale_src1->value(); + scale_dst = opts.scale_dst->value(); + add_bias = opts.add_bias->is_set() ? opts.add_bias->value() : true; + data_type = opts.data_type->value(); + } + CLTensor src0{}, src1{}, src2{}, dst{}; CLTensor tmp_dst{}, biases{}; @@ -357,7 +405,7 @@ private: CLGEMMLowpMatrixMultiplyCore mm_gemmlowp{}; CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint mm_gemmlowp_output_stage{}; - size_t M{ 7 }, N{ 3 }, K{ 5 }; + size_t M{ 7 }, N{ 3 }, K{ 5 }, B{ 1 }; DataType data_type{ DataType::F32 }; float alpha{ 1.0 }, beta{ 0.0 }; int offset_src0{ 10 }, offset_src1{ 10 }, offset_dst{ 10 }; @@ -369,7 +417,7 @@ private: /** Main program for gemm test * * @param[in] argc Number of arguments - * @param[in] argv Arguments ( [optional] datatype, [optional] M, [optional] N, [optional] K, [optional] scale_src0, [optional] offset_src0, [optional] scale_src1, [optional] offset_src1, [optional] scale_dst, [optional] offset_dst, [optional] bias, [optional] alpha, [optional] beta ) + * @param[in] argv Arguments * */ int main(int argc, char **argv) diff --git a/utils/Utils.h b/utils/Utils.h index 788ae4eeb7..afd90a11a3 100644 --- a/utils/Utils.h +++ b/utils/Utils.h @@ -774,7 +774,7 @@ void fill_random_tensor(T &tensor, float lower_bound, float upper_bound) template void init_sgemm_output(T &dst, T &src0, T &src1, arm_compute::DataType dt) { - dst.allocator()->init(TensorInfo(TensorShape(src1.info()->dimension(0), src0.info()->dimension(1)), 1, dt)); + dst.allocator()->init(TensorInfo(TensorShape(src1.info()->dimension(0), src0.info()->dimension(1), src0.info()->dimension(2)), 1, dt)); } /** This function returns the amount of memory free reading from /proc/meminfo * -- cgit v1.2.1