diff options
Diffstat (limited to 'utils/CommonGraphOptions.cpp')
-rw-r--r-- | utils/CommonGraphOptions.cpp | 87 |
1 files changed, 48 insertions, 39 deletions
diff --git a/utils/CommonGraphOptions.cpp b/utils/CommonGraphOptions.cpp index fa9106ceb5..42524d802d 100644 --- a/utils/CommonGraphOptions.cpp +++ b/utils/CommonGraphOptions.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 ARM Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,6 +23,7 @@ */ #include "CommonGraphOptions.h" +#include "arm_compute/core/Utils.h" #include "arm_compute/graph/TypeLoader.h" #include "arm_compute/graph/TypePrinter.h" @@ -36,15 +37,15 @@ namespace { std::pair<unsigned int, unsigned int> parse_validation_range(const std::string &validation_range) { - std::pair<unsigned int /* start */, unsigned int /* end */> range = { 0, std::numeric_limits<unsigned int>::max() }; - if(!validation_range.empty()) + std::pair<unsigned int /* start */, unsigned int /* end */> range = {0, std::numeric_limits<unsigned int>::max()}; + if (!validation_range.empty()) { std::string str; std::stringstream stream(validation_range); // Get first value std::getline(stream, str, ','); - if(stream.fail()) + if (stream.fail()) { return range; } @@ -55,7 +56,7 @@ std::pair<unsigned int, unsigned int> parse_validation_range(const std::string & // Get second value std::getline(stream, str); - if(stream.fail()) + if (stream.fail()) { range.second = range.first; return range; @@ -86,24 +87,27 @@ namespace utils os << "Cache enabled? : " << (common_params.enable_cl_cache ? true_str : false_str) << std::endl; os << "Tuner mode : " << common_params.tuner_mode << std::endl; os << "Tuner file : " << common_params.tuner_file << std::endl; - os << "Fast math enabled? : " << (common_params.fast_math_hint == FastMathHint::Enabled ? true_str : false_str) << std::endl; - if(!common_params.data_path.empty()) + os << "MLGO file : " << common_params.mlgo_file << std::endl; + os << "Fast math enabled? : " << (common_params.fast_math_hint == FastMathHint::Enabled ? true_str : false_str) + << std::endl; + if (!common_params.data_path.empty()) { os << "Data path : " << common_params.data_path << std::endl; } - if(!common_params.image.empty()) + if (!common_params.image.empty()) { os << "Image file : " << common_params.image << std::endl; } - if(!common_params.labels.empty()) + if (!common_params.labels.empty()) { os << "Labels file : " << common_params.labels << std::endl; } - if(!common_params.validation_file.empty()) + if (!common_params.validation_file.empty()) { - os << "Validation range : " << common_params.validation_range_start << "-" << common_params.validation_range_end << std::endl; + os << "Validation range : " << common_params.validation_range_start << "-" << common_params.validation_range_end + << std::endl; os << "Validation file : " << common_params.validation_file << std::endl; - if(!common_params.validation_path.empty()) + if (!common_params.validation_path.empty()) { os << "Validation path : " << common_params.validation_path << std::endl; } @@ -115,6 +119,7 @@ namespace utils CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser) : help(parser.add_option<ToggleOption>("help")), threads(parser.add_option<SimpleOption<int>>("threads", 1)), + batches(parser.add_option<SimpleOption<int>>("batches", 1)), target(), data_type(), data_layout(), @@ -128,34 +133,28 @@ CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser) validation_file(parser.add_option<SimpleOption<std::string>>("validation-file")), validation_path(parser.add_option<SimpleOption<std::string>>("validation-path")), validation_range(parser.add_option<SimpleOption<std::string>>("validation-range")), - tuner_file(parser.add_option<SimpleOption<std::string>>("tuner-file")) + tuner_file(parser.add_option<SimpleOption<std::string>>("tuner-file")), + mlgo_file(parser.add_option<SimpleOption<std::string>>("mlgo-file")) { - std::set<arm_compute::graph::Target> supported_targets - { + std::set<arm_compute::graph::Target> supported_targets{ Target::NEON, Target::CL, - Target::GC, + Target::CLVK, }; - std::set<arm_compute::DataType> supported_data_types - { + std::set<arm_compute::DataType> supported_data_types{ DataType::F16, DataType::F32, DataType::QASYMM8, + DataType::QASYMM8_SIGNED, }; - std::set<DataLayout> supported_data_layouts - { + std::set<DataLayout> supported_data_layouts{ DataLayout::NHWC, DataLayout::NCHW, }; - const std::set<CLTunerMode> supported_tuner_modes - { - CLTunerMode::EXHAUSTIVE, - CLTunerMode::NORMAL, - CLTunerMode::RAPID - }; + const std::set<CLTunerMode> supported_tuner_modes{CLTunerMode::EXHAUSTIVE, CLTunerMode::NORMAL, CLTunerMode::RAPID}; target = parser.add_option<EnumOption<Target>>("target", supported_targets, Target::NEON); data_type = parser.add_option<EnumOption<DataType>>("type", supported_data_types, DataType::F32); @@ -164,12 +163,16 @@ CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser) help->set_help("Show this help message"); threads->set_help("Number of threads to use"); + batches->set_help("Number of batches to use for the inputs"); target->set_help("Target to execute on"); data_type->set_help("Data type to use"); data_layout->set_help("Data layout to use"); enable_tuner->set_help("Enable OpenCL dynamic tuner"); enable_cl_cache->set_help("Enable OpenCL program caches"); - tuner_mode->set_help("Configures the time taken by the tuner to tune. Slow tuner produces the most performant LWS configuration"); + tuner_mode->set_help("Configures the time taken by the tuner to tune. " + "Exhaustive: slowest but produces the most performant LWS configuration. " + "Normal: slow but produces the LWS configurations on par with Exhaustive most of the time. " + "Rapid: fast but produces less performant LWS configurations"); fast_math_hint->set_help("Enable fast math"); data_path->set_help("Path where graph parameters reside"); image->set_help("Input image for the graph"); @@ -178,34 +181,40 @@ CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser) validation_path->set_help("Path to the validation data"); validation_range->set_help("Range of the images to validate for (Format : start,end)"); tuner_file->set_help("File to load/save CLTuner values"); + mlgo_file->set_help("File to load MLGO heuristics"); } CommonGraphParams consume_common_graph_parameters(CommonGraphOptions &options) { - FastMathHint fast_math_hint_value = options.fast_math_hint->value() ? FastMathHint::Enabled : FastMathHint::Disabled; - auto validation_range = parse_validation_range(options.validation_range->value()); + FastMathHint fast_math_hint_value = + options.fast_math_hint->value() ? FastMathHint::Enabled : FastMathHint::Disabled; + auto validation_range = parse_validation_range(options.validation_range->value()); CommonGraphParams common_params; common_params.help = options.help->is_set() ? options.help->value() : false; common_params.threads = options.threads->value(); + common_params.batches = options.batches->value(); common_params.target = options.target->value(); common_params.data_type = options.data_type->value(); - if(options.data_layout->is_set()) + if (options.data_layout->is_set()) { common_params.data_layout = options.data_layout->value(); } - common_params.enable_tuner = options.enable_tuner->is_set() ? options.enable_tuner->value() : false; - common_params.enable_cl_cache = common_params.target == arm_compute::graph::Target::CL ? (options.enable_cl_cache->is_set() ? options.enable_cl_cache->value() : true) : false; - common_params.tuner_mode = options.tuner_mode->value(); - common_params.fast_math_hint = options.fast_math_hint->is_set() ? fast_math_hint_value : FastMathHint::Disabled; - common_params.data_path = options.data_path->value(); - common_params.image = options.image->value(); - common_params.labels = options.labels->value(); - common_params.validation_file = options.validation_file->value(); - common_params.validation_path = options.validation_path->value(); + common_params.enable_tuner = options.enable_tuner->is_set() ? options.enable_tuner->value() : false; + common_params.enable_cl_cache = common_params.target == arm_compute::graph::Target::NEON + ? false + : (options.enable_cl_cache->is_set() ? options.enable_cl_cache->value() : true); + common_params.tuner_mode = options.tuner_mode->value(); + common_params.fast_math_hint = options.fast_math_hint->is_set() ? fast_math_hint_value : FastMathHint::Disabled; + common_params.data_path = options.data_path->value(); + common_params.image = options.image->value(); + common_params.labels = options.labels->value(); + common_params.validation_file = options.validation_file->value(); + common_params.validation_path = options.validation_path->value(); common_params.validation_range_start = validation_range.first; common_params.validation_range_end = validation_range.second; common_params.tuner_file = options.tuner_file->value(); + common_params.mlgo_file = options.mlgo_file->value(); return common_params; } |