From 12be7ab4876f77fecfab903df70791623219b3da Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 3 Jul 2018 12:06:23 +0100 Subject: COMPMID-1310: Create graph validation executables. Change-Id: I9e0b57b1b83fe5a95777cdaeddba6ecef650bafc Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/138697 Reviewed-by: Anthony Barbier Tested-by: Jenkins --- utils/CommonGraphOptions.cpp | 190 ++++++++++++++++++++++++++ utils/CommonGraphOptions.h | 111 +++++++++++++++ utils/GraphUtils.cpp | 108 +++++++++------ utils/GraphUtils.h | 134 +++++++++++------- utils/ImageLoader.h | 26 ++++ utils/Utils.cpp | 41 +++++- utils/Utils.h | 25 +++- utils/command_line/CommandLineOptions.h | 33 +++++ utils/command_line/CommandLineParser.h | 234 ++++++++++++++++++++++++++++++++ utils/command_line/EnumListOption.h | 150 ++++++++++++++++++++ utils/command_line/EnumOption.h | 136 +++++++++++++++++++ utils/command_line/ListOption.h | 118 ++++++++++++++++ utils/command_line/Option.h | 141 +++++++++++++++++++ utils/command_line/SimpleOption.h | 118 ++++++++++++++++ utils/command_line/ToggleOption.h | 82 +++++++++++ 15 files changed, 1555 insertions(+), 92 deletions(-) create mode 100644 utils/CommonGraphOptions.cpp create mode 100644 utils/CommonGraphOptions.h create mode 100644 utils/command_line/CommandLineOptions.h create mode 100644 utils/command_line/CommandLineParser.h create mode 100644 utils/command_line/EnumListOption.h create mode 100644 utils/command_line/EnumOption.h create mode 100644 utils/command_line/ListOption.h create mode 100644 utils/command_line/Option.h create mode 100644 utils/command_line/SimpleOption.h create mode 100644 utils/command_line/ToggleOption.h (limited to 'utils') diff --git a/utils/CommonGraphOptions.cpp b/utils/CommonGraphOptions.cpp new file mode 100644 index 0000000000..d6ff0516aa --- /dev/null +++ b/utils/CommonGraphOptions.cpp @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "CommonGraphOptions.h" + +#include "arm_compute/graph/TypeLoader.h" +#include "arm_compute/graph/TypePrinter.h" + +#include "support/ToolchainSupport.h" + +#include + +using namespace arm_compute::graph; + +namespace +{ +std::pair parse_validation_range(const std::string &validation_range) +{ + std::pair range = { 0, std::numeric_limits::max() }; + if(!validation_range.empty()) + { + std::string str; + std::stringstream stream(validation_range); + + // Get first value + std::getline(stream, str, ','); + if(stream.fail()) + { + return range; + } + else + { + range.first = arm_compute::support::cpp11::stoi(str); + } + + // Get second value + std::getline(stream, str); + if(stream.fail()) + { + range.second = range.first; + return range; + } + else + { + range.second = arm_compute::support::cpp11::stoi(str); + } + } + return range; +} +} // namespace + +namespace arm_compute +{ +namespace utils +{ +::std::ostream &operator<<(::std::ostream &os, const CommonGraphParams &common_params) +{ + std::string false_str = std::string("false"); + std::string true_str = std::string("true"); + + os << "Threads : " << common_params.threads << std::endl; + os << "Target : " << common_params.target << std::endl; + os << "Data type : " << common_params.data_type << std::endl; + os << "Data layout : " << common_params.data_layout << std::endl; + os << "Tuner enabled? : " << (common_params.enable_tuner ? true_str : false_str) << std::endl; + os << "Fast math enabled? : " << (common_params.fast_math_hint == FastMathHint::ENABLED ? true_str : false_str) << std::endl; + if(!common_params.data_path.empty()) + { + os << "Data path : " << common_params.data_path << std::endl; + } + if(!common_params.image.empty()) + { + os << "Image file : " << common_params.image << std::endl; + } + if(!common_params.labels.empty()) + { + os << "Labels file : " << common_params.labels << std::endl; + } + if(!common_params.validation_file.empty()) + { + os << "Validation range : " << common_params.validation_range_start << "-" << common_params.validation_range_end << std::endl; + os << "Validation file : " << common_params.validation_file << std::endl; + if(!common_params.validation_path.empty()) + { + os << "Validation path : " << common_params.validation_path << std::endl; + } + } + + return os; +} + +CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser) + : help(parser.add_option("help")), + threads(parser.add_option>("threads", 1)), + target(), + data_type(), + data_layout(), + enable_tuner(parser.add_option("enable-tuner")), + fast_math_hint(parser.add_option("fast-math")), + data_path(parser.add_option>("data")), + image(parser.add_option>("image")), + labels(parser.add_option>("labels")), + validation_file(parser.add_option>("validation-file")), + validation_path(parser.add_option>("validation-path")), + validation_range(parser.add_option>("validation-range")) +{ + std::set supported_targets + { + Target::NEON, + Target::CL, + Target::GC, + }; + + std::set supported_data_types + { + DataType::F16, + DataType::F32, + DataType::QASYMM8, + }; + + std::set supported_data_layouts + { + DataLayout::NHWC, + DataLayout::NCHW, + }; + + target = parser.add_option>("target", supported_targets, Target::NEON); + data_type = parser.add_option>("type", supported_data_types, DataType::F32); + data_layout = parser.add_option>("layout", supported_data_layouts, DataLayout::NCHW); + + help->set_help("Show this help message"); + threads->set_help("Number of threads to use"); + target->set_help("Target to execute on"); + data_type->set_help("Data type to use"); + data_layout->set_help("Data layout to use"); + enable_tuner->set_help("Enable tuner"); + fast_math_hint->set_help("Enable fast math"); + data_path->set_help("Path where graph parameters reside"); + image->set_help("Input image for the graph"); + labels->set_help("File containing the output labels"); + validation_file->set_help("File used to validate the graph"); + validation_path->set_help("Path to the validation data"); + validation_range->set_help("Range of the images to validate for (Format : start,end)"); +} + +CommonGraphParams consume_common_graph_parameters(CommonGraphOptions &options) +{ + FastMathHint fast_math_hint_value = options.fast_math_hint->value() ? FastMathHint::ENABLED : FastMathHint::DISABLED; + auto validation_range = parse_validation_range(options.validation_range->value()); + + CommonGraphParams common_params; + common_params.help = options.help->is_set() ? options.help->value() : false; + common_params.threads = options.threads->value(); + common_params.target = options.target->value(); + common_params.data_type = options.data_type->value(); + common_params.data_layout = options.data_layout->value(); + common_params.enable_tuner = options.enable_tuner->is_set() ? options.enable_tuner->value() : false; + common_params.fast_math_hint = options.fast_math_hint->is_set() ? fast_math_hint_value : FastMathHint::DISABLED; + common_params.data_path = options.data_path->value(); + common_params.image = options.image->value(); + common_params.labels = options.labels->value(); + common_params.validation_file = options.validation_file->value(); + common_params.validation_path = options.validation_path->value(); + common_params.validation_range_start = validation_range.first; + common_params.validation_range_end = validation_range.second; + + return common_params; +} +} // namespace utils +} // namespace arm_compute diff --git a/utils/CommonGraphOptions.h b/utils/CommonGraphOptions.h new file mode 100644 index 0000000000..ef2e4fb946 --- /dev/null +++ b/utils/CommonGraphOptions.h @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_EXAMPLES_UTILS_COMMON_GRAPH_OPTIONS +#define ARM_COMPUTE_EXAMPLES_UTILS_COMMON_GRAPH_OPTIONS + +#include "utils/command_line/CommandLineOptions.h" +#include "utils/command_line/CommandLineParser.h" + +#include "arm_compute/graph/TypeLoader.h" +#include "arm_compute/graph/TypePrinter.h" + +namespace arm_compute +{ +namespace utils +{ +/** Structure holding all the common graph parameters */ +struct CommonGraphParams +{ + bool help{ false }; + int threads{ 0 }; + arm_compute::graph::Target target{ arm_compute::graph::Target::NEON }; + arm_compute::DataType data_type{ DataType::F32 }; + arm_compute::DataLayout data_layout{ DataLayout::NCHW }; + bool enable_tuner{ false }; + arm_compute::graph::FastMathHint fast_math_hint{ arm_compute::graph::FastMathHint::DISABLED }; + std::string data_path{}; + std::string image{}; + std::string labels{}; + std::string validation_file{}; + std::string validation_path{}; + unsigned int validation_range_start{ 0 }; + unsigned int validation_range_end{ std::numeric_limits::max() }; +}; + +/** Formatted output of the CommonGraphParams type + * + * @param[out] os Output stream. + * @param[in] common_params Common parameters to output + * + * @return Modified output stream. + */ +::std::ostream &operator<<(::std::ostream &os, const CommonGraphParams &common_params); + +/** Common command line options used to configure the graph examples + * + * The options in this object get populated when "parse()" is called on the parser used to construct it. + * The expected workflow is: + * + * CommandLineParser parser; + * CommonOptions options( parser ); + * parser.parse(argc, argv); + */ +class CommonGraphOptions +{ +public: + /** Constructor + * + * @param[in,out] parser A parser on which "parse()" hasn't been called yet. + */ + CommonGraphOptions(CommandLineParser &parser); + /** Prevent instances of this class from being copy constructed */ + CommonGraphOptions(const CommonGraphOptions &) = delete; + /** Prevent instances of this class from being copied */ + CommonGraphOptions &operator=(const CommonGraphOptions &) = delete; + + ToggleOption *help; /**< Show help option */ + SimpleOption *threads; /**< Number of threads option */ + EnumOption *target; /**< Graph execution target */ + EnumOption *data_type; /**< Graph data type */ + EnumOption *data_layout; /**< Graph data layout */ + ToggleOption *enable_tuner; /**< Enable tuner */ + ToggleOption *fast_math_hint; /**< Fast math hint */ + SimpleOption *data_path; /**< Trainable parameters path */ + SimpleOption *image; /**< Image */ + SimpleOption *labels; /**< Labels */ + SimpleOption *validation_file; /**< Validation file */ + SimpleOption *validation_path; /**< Validation data path */ + SimpleOption *validation_range; /**< Validation range */ +}; + +/** Consumes the common graph options and creates a structure containing any information + * + * @param[in] options Options to consume + * + * @return Structure containing the commnon graph parameters + */ +CommonGraphParams consume_common_graph_parameters(CommonGraphOptions &options); +} // namespace utils +} // namespace arm_compute +#endif /* ARM_COMPUTE_EXAMPLES_UTILS_COMMON_GRAPH_OPTIONS */ diff --git a/utils/GraphUtils.cpp b/utils/GraphUtils.cpp index d94dcb0d86..5c1fda5ca6 100644 --- a/utils/GraphUtils.cpp +++ b/utils/GraphUtils.cpp @@ -26,11 +26,13 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/Types.h" +#include "arm_compute/graph/Logger.h" #include "arm_compute/runtime/SubTensor.h" #include "utils/ImageLoader.h" #include "utils/Utils.h" #include +#include using namespace arm_compute::graph_utils; @@ -169,17 +171,18 @@ bool NumPyAccessor::access_tensor(ITensor &tensor) return false; } -PPMAccessor::PPMAccessor(std::string ppm_path, bool bgr, std::unique_ptr preprocessor) - : _ppm_path(std::move(ppm_path)), _bgr(bgr), _preprocessor(std::move(preprocessor)) +ImageAccessor::ImageAccessor(std::string filename, bool bgr, std::unique_ptr preprocessor) + : _filename(std::move(filename)), _bgr(bgr), _preprocessor(std::move(preprocessor)) { } -bool PPMAccessor::access_tensor(ITensor &tensor) +bool ImageAccessor::access_tensor(ITensor &tensor) { - utils::PPMLoader ppm; + auto image_loader = utils::ImageLoaderFactory::create(_filename); + ARM_COMPUTE_ERROR_ON_MSG(image_loader == nullptr, "Unsupported image type"); - // Open PPM file - ppm.open(_ppm_path); + // Open image file + image_loader->open(_filename); // Get permutated shape and permutation parameters TensorShape permuted_shape = tensor.info()->tensor_shape(); @@ -188,11 +191,12 @@ bool PPMAccessor::access_tensor(ITensor &tensor) { std::tie(permuted_shape, perm) = compute_permutation_paramaters(tensor.info()->tensor_shape(), tensor.info()->data_layout()); } - ARM_COMPUTE_ERROR_ON_MSG(ppm.width() != permuted_shape.x() || ppm.height() != permuted_shape.y(), - "Failed to load image file: dimensions [%d,%d] not correct, expected [%d,%d].", ppm.width(), ppm.height(), permuted_shape.x(), permuted_shape.y()); + ARM_COMPUTE_ERROR_ON_MSG(image_loader->width() != permuted_shape.x() || image_loader->height() != permuted_shape.y(), + "Failed to load image file: dimensions [%d,%d] not correct, expected [%d,%d].", + image_loader->width(), image_loader->height(), permuted_shape.x(), permuted_shape.y()); // Fill the tensor with the PPM content (BGR) - ppm.fill_planar_tensor(tensor, _bgr); + image_loader->fill_planar_tensor(tensor, _bgr); // Preprocess tensor if(_preprocessor) @@ -203,12 +207,13 @@ bool PPMAccessor::access_tensor(ITensor &tensor) return true; } -ValidationInputAccessor::ValidationInputAccessor(const std::string &image_list, - std::string images_path, - bool bgr, - unsigned int start, - unsigned int end) - : _path(std::move(images_path)), _images(), _bgr(bgr), _offset(0) +ValidationInputAccessor::ValidationInputAccessor(const std::string &image_list, + std::string images_path, + std::unique_ptr preprocessor, + bool bgr, + unsigned int start, + unsigned int end) + : _path(std::move(images_path)), _images(), _preprocessor(std::move(preprocessor)), _bgr(bgr), _offset(0) { ARM_COMPUTE_ERROR_ON_MSG(start > end, "Invalid validation range!"); @@ -247,7 +252,9 @@ bool ValidationInputAccessor::access_tensor(arm_compute::ITensor &tensor) utils::JPEGLoader jpeg; // Open JPEG file - jpeg.open(_path + _images[_offset++]); + std::string image_name = _path + _images[_offset++]; + jpeg.open(image_name); + ARM_COMPUTE_LOG_GRAPH_INFO("Validating " << image_name << std::endl); // Get permutated shape and permutation parameters TensorShape permuted_shape = tensor.info()->tensor_shape(); @@ -261,22 +268,26 @@ bool ValidationInputAccessor::access_tensor(arm_compute::ITensor &tensor) "Failed to load image file: dimensions [%d,%d] not correct, expected [%d,%d].", jpeg.width(), jpeg.height(), permuted_shape.x(), permuted_shape.y()); - // Fill the tensor with the PPM content (BGR) + // Fill the tensor with the JPEG content (BGR) jpeg.fill_planar_tensor(tensor, _bgr); + + // Preprocess tensor + if(_preprocessor) + { + _preprocessor->preprocess(tensor); + } } return ret; } ValidationOutputAccessor::ValidationOutputAccessor(const std::string &image_list, - size_t top_n, std::ostream &output_stream, unsigned int start, unsigned int end) - : _results(), _output_stream(output_stream), _top_n(top_n), _offset(0), _positive_samples(0) + : _results(), _output_stream(output_stream), _offset(0), _positive_samples_top1(0), _positive_samples_top5(0) { ARM_COMPUTE_ERROR_ON_MSG(start > end, "Invalid validation range!"); - ARM_COMPUTE_ERROR_ON(top_n == 0); std::ifstream ifs; try @@ -308,13 +319,15 @@ ValidationOutputAccessor::ValidationOutputAccessor(const std::string &image_list void ValidationOutputAccessor::reset() { - _offset = 0; - _positive_samples = 0; + _offset = 0; + _positive_samples_top1 = 0; + _positive_samples_top5 = 0; } bool ValidationOutputAccessor::access_tensor(arm_compute::ITensor &tensor) { - if(_offset < _results.size()) + bool ret = _offset < _results.size(); + if(ret) { // Get results std::vector tensor_results; @@ -332,30 +345,16 @@ bool ValidationOutputAccessor::access_tensor(arm_compute::ITensor &tensor) // Check if tensor results are within top-n accuracy size_t correct_label = _results[_offset++]; - auto is_valid_label = [&](size_t label) - { - return label == correct_label; - }; - if(std::any_of(std::begin(tensor_results), std::begin(tensor_results) + _top_n - 1, is_valid_label)) - { - ++_positive_samples; - } + aggregate_sample(tensor_results, _positive_samples_top1, 1, correct_label); + aggregate_sample(tensor_results, _positive_samples_top5, 5, correct_label); } // Report top_n accuracy - bool ret = _offset >= _results.size(); - if(ret) + if(_offset >= _results.size()) { - size_t total_samples = _results.size(); - size_t negative_samples = total_samples - _positive_samples; - float accuracy = _positive_samples / static_cast(total_samples); - - _output_stream << "----------Top " << _top_n << " accuracy ----------" << std::endl - << std::endl; - _output_stream << "Positive samples : " << _positive_samples << std::endl; - _output_stream << "Negative samples : " << negative_samples << std::endl; - _output_stream << "Accuracy : " << accuracy << std::endl; + report_top_n(1, _results.size(), _positive_samples_top1); + report_top_n(5, _results.size(), _positive_samples_top5); } return ret; @@ -383,6 +382,31 @@ std::vector ValidationOutputAccessor::access_predictions_tensor(arm_comp return index; } +void ValidationOutputAccessor::aggregate_sample(const std::vector &res, size_t &positive_samples, size_t top_n, size_t correct_label) +{ + auto is_valid_label = [correct_label](size_t label) + { + return label == correct_label; + }; + + if(std::any_of(std::begin(res), std::begin(res) + top_n, is_valid_label)) + { + ++positive_samples; + } +} + +void ValidationOutputAccessor::report_top_n(size_t top_n, size_t total_samples, size_t positive_samples) +{ + size_t negative_samples = total_samples - positive_samples; + float accuracy = positive_samples / static_cast(total_samples); + + _output_stream << "----------Top " << top_n << " accuracy ----------" << std::endl + << std::endl; + _output_stream << "Positive samples : " << positive_samples << std::endl; + _output_stream << "Negative samples : " << negative_samples << std::endl; + _output_stream << "Accuracy : " << accuracy << std::endl; +} + TopNPredictionsAccessor::TopNPredictionsAccessor(const std::string &labels_path, size_t top_n, std::ostream &output_stream) : _labels(), _output_stream(output_stream), _top_n(top_n) { diff --git a/utils/GraphUtils.h b/utils/GraphUtils.h index 768c608d26..8558b9066c 100644 --- a/utils/GraphUtils.h +++ b/utils/GraphUtils.h @@ -31,6 +31,8 @@ #include "arm_compute/graph/Types.h" #include "arm_compute/runtime/Tensor.h" +#include "utils/CommonGraphOptions.h" + #include #include #include @@ -150,25 +152,25 @@ private: std::ostream &_output_stream; }; -/** PPM accessor class */ -class PPMAccessor final : public graph::ITensorAccessor +/** Image accessor class */ +class ImageAccessor final : public graph::ITensorAccessor { public: /** Constructor * - * @param[in] ppm_path Path to PPM file + * @param[in] filename Image file * @param[in] bgr (Optional) Fill the first plane with blue channel (default = false - RGB format) - * @param[in] preprocessor (Optional) PPM pre-processing object + * @param[in] preprocessor (Optional) Image pre-processing object */ - PPMAccessor(std::string ppm_path, bool bgr = true, std::unique_ptr preprocessor = nullptr); + ImageAccessor(std::string filename, bool bgr = true, std::unique_ptr preprocessor = nullptr); /** Allow instances of this class to be move constructed */ - PPMAccessor(PPMAccessor &&) = default; + ImageAccessor(ImageAccessor &&) = default; // Inherited methods overriden: bool access_tensor(ITensor &tensor) override; private: - const std::string _ppm_path; + const std::string _filename; const bool _bgr; std::unique_ptr _preprocessor; }; @@ -179,28 +181,31 @@ class ValidationInputAccessor final : public graph::ITensorAccessor public: /** Constructor * - * @param[in] image_list File containing all the images to validate - * @param[in] images_path Path to images. - * @param[in] bgr (Optional) Fill the first plane with blue channel (default = false - RGB format) - * @param[in] start (Optional) Start range - * @param[in] end (Optional) End range + * @param[in] image_list File containing all the images to validate + * @param[in] images_path Path to images. + * @param[in] bgr (Optional) Fill the first plane with blue channel (default = false - RGB format) + * @param[in] preprocessor (Optional) Image pre-processing object (default = nullptr) + * @param[in] start (Optional) Start range + * @param[in] end (Optional) End range * * @note Range is defined as [start, end] */ - ValidationInputAccessor(const std::string &image_list, - std::string images_path, - bool bgr = true, - unsigned int start = 0, - unsigned int end = 0); + ValidationInputAccessor(const std::string &image_list, + std::string images_path, + std::unique_ptr preprocessor = nullptr, + bool bgr = true, + unsigned int start = 0, + unsigned int end = 0); // Inherited methods overriden: bool access_tensor(ITensor &tensor) override; private: - std::string _path; - std::vector _images; - bool _bgr; - size_t _offset; + std::string _path; + std::vector _images; + std::unique_ptr _preprocessor; + bool _bgr; + size_t _offset; }; /** Output Accessor used for network validation */ @@ -210,7 +215,6 @@ public: /** Default Constructor * * @param[in] image_list File containing all the images and labels results - * @param[in] top_n (Optional) Top N accuracy (Defaults to 5) * @param[out] output_stream (Optional) Output stream (Defaults to the standard output stream) * @param[in] start (Optional) Start range * @param[in] end (Optional) End range @@ -218,7 +222,6 @@ public: * @note Range is defined as [start, end] */ ValidationOutputAccessor(const std::string &image_list, - size_t top_n = 5, std::ostream &output_stream = std::cout, unsigned int start = 0, unsigned int end = 0); @@ -237,13 +240,28 @@ private: */ template std::vector access_predictions_tensor(ITensor &tensor); + /** Aggregates the results of a sample + * + * @param[in] res Vector containing the results of a graph + * @param[in,out] positive_samples Positive samples to be updated + * @param[in] top_n Top n accuracy to measure + * @param[in] correct_label Correct label of the current sample + */ + void aggregate_sample(const std::vector &res, size_t &positive_samples, size_t top_n, size_t correct_label); + /** Reports top N accuracy + * + * @param[in] top_n Top N accuracy that is being reported + * @param[in] total_samples Total number of samples + * @param[in] positive_samples Positive samples + */ + void report_top_n(size_t top_n, size_t total_samples, size_t positive_samples); private: std::vector _results; std::ostream &_output_stream; - size_t _top_n; size_t _offset; - size_t _positive_samples; + size_t _positive_samples_top1; + size_t _positive_samples_top5; }; /** Result accessor class */ @@ -359,56 +377,78 @@ inline std::unique_ptr get_weights_accessor(const std::s } } -/** Generates appropriate input accessor according to the specified ppm_path +/** Generates appropriate input accessor according to the specified graph parameters * - * @note If ppm_path is empty will generate a DummyAccessor else will generate a PPMAccessor - * - * @param[in] ppm_path Path to PPM file - * @param[in] preprocessor Preproccessor object - * @param[in] bgr (Optional) Fill the first plane with blue channel (default = true) + * @param[in] graph_parameters Graph parameters + * @param[in] preprocessor (Optional) Preproccessor object + * @param[in] bgr (Optional) Fill the first plane with blue channel (default = true) * * @return An appropriate tensor accessor */ -inline std::unique_ptr get_input_accessor(const std::string &ppm_path, - std::unique_ptr preprocessor = nullptr, - bool bgr = true) +inline std::unique_ptr get_input_accessor(const arm_compute::utils::CommonGraphParams &graph_parameters, + std::unique_ptr preprocessor = nullptr, + bool bgr = true) { - if(ppm_path.empty()) + if(!graph_parameters.validation_file.empty()) { - return arm_compute::support::cpp14::make_unique(); + return arm_compute::support::cpp14::make_unique(graph_parameters.validation_file, + graph_parameters.validation_path, + std::move(preprocessor), + bgr, + graph_parameters.validation_range_start, + graph_parameters.validation_range_end); } else { - if(arm_compute::utility::endswith(ppm_path, ".npy")) + const std::string &image_file = graph_parameters.image; + if(arm_compute::utility::endswith(image_file, ".npy")) + { + return arm_compute::support::cpp14::make_unique(image_file); + } + else if(arm_compute::utility::endswith(image_file, ".jpeg") + || arm_compute::utility::endswith(image_file, ".jpg") + || arm_compute::utility::endswith(image_file, ".ppm")) { - return arm_compute::support::cpp14::make_unique(ppm_path); + return arm_compute::support::cpp14::make_unique(image_file, bgr, std::move(preprocessor)); } else { - return arm_compute::support::cpp14::make_unique(ppm_path, bgr, std::move(preprocessor)); + return arm_compute::support::cpp14::make_unique(); } } } -/** Generates appropriate output accessor according to the specified labels_path +/** Generates appropriate output accessor according to the specified graph parameters * - * @note If labels_path is empty will generate a DummyAccessor else will generate a TopNPredictionsAccessor + * @note If the output accessor is requested to validate the graph then ValidationOutputAccessor is generated + * else if output_accessor_file is empty will generate a DummyAccessor else will generate a TopNPredictionsAccessor * - * @param[in] labels_path Path to labels text file - * @param[in] top_n (Optional) Number of output classes to print - * @param[out] output_stream (Optional) Output stream + * @param[in] graph_parameters Graph parameters + * @param[in] top_n (Optional) Number of output classes to print (default = 5) + * @param[in] is_validation (Optional) Validation flag (default = false) + * @param[out] output_stream (Optional) Output stream (default = std::cout) * * @return An appropriate tensor accessor */ -inline std::unique_ptr get_output_accessor(const std::string &labels_path, size_t top_n = 5, std::ostream &output_stream = std::cout) +inline std::unique_ptr get_output_accessor(const arm_compute::utils::CommonGraphParams &graph_parameters, + size_t top_n = 5, + bool is_validation = false, + std::ostream &output_stream = std::cout) { - if(labels_path.empty()) + if(!graph_parameters.validation_file.empty()) + { + return arm_compute::support::cpp14::make_unique(graph_parameters.validation_file, + output_stream, + graph_parameters.validation_range_start, + graph_parameters.validation_range_end); + } + else if(graph_parameters.labels.empty()) { return arm_compute::support::cpp14::make_unique(0); } else { - return arm_compute::support::cpp14::make_unique(labels_path, top_n, output_stream); + return arm_compute::support::cpp14::make_unique(graph_parameters.labels, top_n, output_stream); } } /** Generates appropriate npy output accessor according to the specified npy_path diff --git a/utils/ImageLoader.h b/utils/ImageLoader.h index edc89286a2..cc9619d3f1 100644 --- a/utils/ImageLoader.h +++ b/utils/ImageLoader.h @@ -486,6 +486,32 @@ private: bool _is_loaded; std::unique_ptr _data; }; + +/** Factory for generating appropriate image loader**/ +class ImageLoaderFactory final +{ +public: + /** Create an image loader depending on the image type + * + * @param[in] filename File than needs to be loaded + * + * @return Image loader + */ + static std::unique_ptr create(const std::string &filename) + { + ImageType type = arm_compute::utils::get_image_type_from_file(filename); + switch(type) + { + case ImageType::PPM: + return support::cpp14::make_unique(); + case ImageType::JPEG: + return support::cpp14::make_unique(); + case ImageType::UNKNOWN: + default: + return nullptr; + } + } +}; } // namespace utils } // namespace arm_compute #endif /* __UTILS_IMAGE_LOADER_H__*/ diff --git a/utils/Utils.cpp b/utils/Utils.cpp index a5c6a95a2a..133248e30c 100644 --- a/utils/Utils.cpp +++ b/utils/Utils.cpp @@ -74,7 +74,11 @@ int run_example(int argc, char **argv, std::unique_ptr example) try { - example->do_setup(argc, argv); + bool status = example->do_setup(argc, argv); + if(!status) + { + return 1; + } example->do_run(); example->do_teardown(); @@ -141,6 +145,41 @@ void draw_detection_rectangle(ITensor *tensor, const DetectionWindow &rect, uint } } +ImageType get_image_type_from_file(const std::string &filename) +{ + ImageType type = ImageType::UNKNOWN; + + try + { + // Open file + std::ifstream fs; + fs.exceptions(std::ifstream::failbit | std::ifstream::badbit); + fs.open(filename, std::ios::in | std::ios::binary); + + // Identify type from magic number + std::array magic_number{ { 0 } }; + fs >> magic_number[0] >> magic_number[1]; + + // PPM check + if(static_cast(magic_number[0]) == 'P' && static_cast(magic_number[1]) == '6') + { + type = ImageType::PPM; + } + else if(magic_number[0] == 0xFF && magic_number[1] == 0xD8) + { + type = ImageType::JPEG; + } + + fs.close(); + } + catch(std::runtime_error &e) + { + ARM_COMPUTE_ERROR("Accessing %s: %s", filename.c_str(), e.what()); + } + + return type; +} + std::tuple parse_ppm_header(std::ifstream &fs) { // Check the PPM magic number is valid diff --git a/utils/Utils.h b/utils/Utils.h index c18ad217a4..ced6b147f1 100644 --- a/utils/Utils.h +++ b/utils/Utils.h @@ -55,6 +55,14 @@ namespace arm_compute { namespace utils { +/** Supported image types */ +enum class ImageType +{ + UNKNOWN, + PPM, + JPEG +}; + /** Abstract Example class. * * All examples have to inherit from this class. @@ -66,8 +74,13 @@ public: * * @param[in] argc Argument count. * @param[in] argv Argument values. + * + * @return True in case of no errors in setup else false */ - virtual void do_setup(int argc, char **argv) {}; + virtual bool do_setup(int argc, char **argv) + { + return true; + }; /** Run the example. */ virtual void do_run() {}; /** Teardown the example. */ @@ -101,6 +114,14 @@ int run_example(int argc, char **argv) */ void draw_detection_rectangle(arm_compute::ITensor *tensor, const arm_compute::DetectionWindow &rect, uint8_t r, uint8_t g, uint8_t b); +/** Gets image type given a file + * + * @param[in] filename File to identify its image type + * + * @return Image type + */ +ImageType get_image_type_from_file(const std::string &filename); + /** Parse the ppm header from an input file stream. At the end of the execution, * the file position pointer will be located at the first pixel stored in the ppm file * @@ -167,7 +188,7 @@ inline std::string get_typestring(DataType data_type) case DataType::SIZET: return endianness + "u" + support::cpp11::to_string(sizeof(size_t)); default: - ARM_COMPUTE_ERROR("NOT SUPPORTED!"); + ARM_COMPUTE_ERROR("Data type not supported"); } } diff --git a/utils/command_line/CommandLineOptions.h b/utils/command_line/CommandLineOptions.h new file mode 100644 index 0000000000..8f82815020 --- /dev/null +++ b/utils/command_line/CommandLineOptions.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_UTILS_COMMANDLINEOPTIONS +#define ARM_COMPUTE_UTILS_COMMANDLINEOPTIONS + +#include "EnumListOption.h" +#include "EnumOption.h" +#include "ListOption.h" +#include "Option.h" +#include "ToggleOption.h" + +#endif /* ARM_COMPUTE_UTILS_COMMANDLINEOPTIONS */ diff --git a/utils/command_line/CommandLineParser.h b/utils/command_line/CommandLineParser.h new file mode 100644 index 0000000000..06c4bf5e2f --- /dev/null +++ b/utils/command_line/CommandLineParser.h @@ -0,0 +1,234 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_UTILS_COMMANDLINEPARSER +#define ARM_COMPUTE_UTILS_COMMANDLINEPARSER + +#include "Option.h" +#include "arm_compute/core/utils/misc/Utility.h" +#include "support/ToolchainSupport.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace arm_compute +{ +namespace utils +{ +/** Class to parse command line arguments. */ +class CommandLineParser final +{ +public: + /** Default constructor. */ + CommandLineParser() = default; + + /** Function to add a new option to the parser. + * + * @param[in] name Name of the option. Will be available under --name=VALUE. + * @param[in] args Option specific configuration arguments. + * + * @return Pointer to the option. The option is owned by the parser. + */ + template + T *add_option(const std::string &name, As &&... args); + + /** Function to add a new positional argument to the parser. + * + * @param[in] args Option specific configuration arguments. + * + * @return Pointer to the option. The option is owned by the parser. + */ + template + T *add_positional_option(As &&... args); + + /** Parses the command line arguments and updates the options accordingly. + * + * @param[in] argc Number of arguments. + * @param[in] argv Arguments. + */ + void parse(int argc, char **argv); + + /** Validates the previously parsed command line arguments. + * + * Validation fails if not all required options are provided. Additionally + * warnings are generated for options that have illegal values or unknown + * options. + * + * @return True if all required options have been provided. + */ + bool validate() const; + + /** Prints a help message for all configured options. + * + * @param[in] program_name Name of the program to be used in the help message. + */ + void print_help(const std::string &program_name) const; + +private: + using OptionsMap = std::map>; + using PositionalOptionsVector = std::vector>; + + OptionsMap _options{}; + PositionalOptionsVector _positional_options{}; + std::vector _unknown_options{}; + std::vector _invalid_options{}; +}; + +template +inline T *CommandLineParser::add_option(const std::string &name, As &&... args) +{ + auto result = _options.emplace(name, support::cpp14::make_unique(name, std::forward(args)...)); + return static_cast(result.first->second.get()); +} + +template +inline T *CommandLineParser::add_positional_option(As &&... args) +{ + _positional_options.emplace_back(support::cpp14::make_unique(std::forward(args)...)); + return static_cast(_positional_options.back().get()); +} + +inline void CommandLineParser::parse(int argc, char **argv) +{ + const std::regex option_regex{ "--((?:no-)?)([^=]+)(?:=(.*))?" }; + + const auto set_option = [&](const std::string & option, const std::string & name, const std::string & value) + { + if(_options.find(name) == _options.end()) + { + _unknown_options.push_back(option); + return; + } + + const bool success = _options[name]->parse(value); + + if(!success) + { + _invalid_options.push_back(option); + } + }; + + unsigned int positional_index = 0; + + for(int i = 1; i < argc; ++i) + { + std::string mixed_case_opt{ argv[i] }; + int equal_sign = mixed_case_opt.find('='); + int pos = (equal_sign == -1) ? strlen(argv[i]) : equal_sign; + + const std::string option = arm_compute::utility::tolower(mixed_case_opt.substr(0, pos)) + mixed_case_opt.substr(pos); + std::smatch option_matches; + + if(std::regex_match(option, option_matches, option_regex)) + { + // Boolean option + if(option_matches.str(3).empty()) + { + set_option(option, option_matches.str(2), option_matches.str(1).empty() ? "true" : "false"); + } + else + { + // Can't have "no-" and a value + if(!option_matches.str(1).empty()) + { + _invalid_options.emplace_back(option); + } + else + { + set_option(option, option_matches.str(2), option_matches.str(3)); + } + } + } + else + { + if(positional_index >= _positional_options.size()) + { + _invalid_options.push_back(mixed_case_opt); + } + else + { + _positional_options[positional_index]->parse(mixed_case_opt); + ++positional_index; + } + } + } +} + +inline bool CommandLineParser::validate() const +{ + bool is_valid = true; + + for(const auto &option : _options) + { + if(option.second->is_required() && !option.second->is_set()) + { + is_valid = false; + std::cerr << "ERROR: Option '" << option.second->name() << "' is required but not given!\n"; + } + } + + for(const auto &option : _positional_options) + { + if(option->is_required() && !option->is_set()) + { + is_valid = false; + std::cerr << "ERROR: Option '" << option->name() << "' is required but not given!\n"; + } + } + + for(const auto &option : _unknown_options) + { + std::cerr << "WARNING: Skipping unknown option '" << option << "'!\n"; + } + + for(const auto &option : _invalid_options) + { + std::cerr << "WARNING: Skipping invalid option '" << option << "'!\n"; + } + + return is_valid; +} + +inline void CommandLineParser::print_help(const std::string &program_name) const +{ + std::cout << "usage: " << program_name << " \n"; + + for(const auto &option : _options) + { + std::cout << option.second->help() << "\n"; + } + + for(const auto &option : _positional_options) + { + //FIXME: Print help string as well + std::cout << option->name() << "\n"; + } +} +} // namespace utils +} // namespace arm_compute +#endif /* ARM_COMPUTE_UTILS_COMMANDLINEPARSER */ diff --git a/utils/command_line/EnumListOption.h b/utils/command_line/EnumListOption.h new file mode 100644 index 0000000000..834becbaef --- /dev/null +++ b/utils/command_line/EnumListOption.h @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_UTILS_ENUMLISTOPTION +#define ARM_COMPUTE_UTILS_ENUMLISTOPTION + +#include "Option.h" + +#include +#include +#include +#include +#include +#include + +namespace arm_compute +{ +namespace utils +{ +/** Implementation of an option that accepts any number of values from a fixed set. */ +template +class EnumListOption : public Option +{ +public: + /** Construct option with allowed values. + * + * @param[in] name Name of the option. + * @param[in] allowed_values Set of allowed values for the option. + */ + EnumListOption(std::string name, std::set allowed_values); + + /** Construct option with allowed values, a fixed number of accepted values and default values for the option. + * + * @param[in] name Name of the option. + * @param[in] allowed_values Set of allowed values for the option. + * @param[in] default_values Default values. + */ + EnumListOption(std::string name, std::set allowed_values, std::initializer_list &&default_values); + + bool parse(std::string value) override; + std::string help() const override; + + /** Get the values of the option. + * + * @return a list of the selected option values. + */ + const std::vector &value() const; + +private: + std::vector _values{}; + std::set _allowed_values{}; +}; + +template +inline EnumListOption::EnumListOption(std::string name, std::set allowed_values) + : Option{ std::move(name) }, _allowed_values{ std::move(allowed_values) } +{ +} + +template +inline EnumListOption::EnumListOption(std::string name, std::set allowed_values, std::initializer_list &&default_values) + : Option{ std::move(name), false, true }, _values{ std::forward>(default_values) }, _allowed_values{ std::move(allowed_values) } +{ +} + +template +bool EnumListOption::parse(std::string value) +{ + // Remove default values + _values.clear(); + _is_set = true; + + std::stringstream stream{ value }; + std::string item; + + while(!std::getline(stream, item, ',').fail()) + { + try + { + std::stringstream item_stream(item); + T typed_value{}; + + item_stream >> typed_value; + + if(!item_stream.fail()) + { + if(_allowed_values.count(typed_value) == 0) + { + _is_set = false; + continue; + } + + _values.emplace_back(typed_value); + } + + _is_set = _is_set && !item_stream.fail(); + } + catch(const std::invalid_argument &) + { + _is_set = false; + } + } + + return _is_set; +} + +template +std::string EnumListOption::help() const +{ + std::stringstream msg; + msg << "--" + name() + "={"; + + for(const auto &value : _allowed_values) + { + msg << value << ","; + } + + msg << "}[,{...}[,...]] - " << _help; + + return msg.str(); +} + +template +inline const std::vector &EnumListOption::value() const +{ + return _values; +} +} // namespace utils +} // namespace arm_compute +#endif /* ARM_COMPUTE_UTILS_ENUMLISTOPTION */ diff --git a/utils/command_line/EnumOption.h b/utils/command_line/EnumOption.h new file mode 100644 index 0000000000..b775db23fb --- /dev/null +++ b/utils/command_line/EnumOption.h @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_UTILS_ENUMOPTION +#define ARM_COMPUTE_UTILS_ENUMOPTION + +#include "SimpleOption.h" + +#include +#include +#include +#include + +namespace arm_compute +{ +namespace utils +{ +/** Implementation of a simple option that accepts a value from a fixed set. */ +template +class EnumOption : public SimpleOption +{ +public: + /** Construct option with allowed values. + * + * @param[in] name Name of the option. + * @param[in] allowed_values Set of allowed values for the option. + */ + EnumOption(std::string name, std::set allowed_values); + + /** Construct option with allowed values, a fixed number of accepted values and default values for the option. + * + * @param[in] name Name of the option. + * @param[in] allowed_values Set of allowed values for the option. + * @param[in] default_value Default value. + */ + EnumOption(std::string name, std::set allowed_values, T default_value); + + bool parse(std::string value) override; + std::string help() const override; + + /** Get the selected value. + * + * @return get the selected enum value. + */ + const T &value() const; + +private: + std::set _allowed_values{}; +}; + +template +inline EnumOption::EnumOption(std::string name, std::set allowed_values) + : SimpleOption{ std::move(name) }, _allowed_values{ std::move(allowed_values) } +{ +} + +template +inline EnumOption::EnumOption(std::string name, std::set allowed_values, T default_value) + : SimpleOption{ std::move(name), std::move(default_value) }, _allowed_values{ std::move(allowed_values) } +{ +} + +template +bool EnumOption::parse(std::string value) +{ + try + { + std::stringstream stream{ value }; + T typed_value{}; + + stream >> typed_value; + + if(!stream.fail()) + { + if(_allowed_values.count(typed_value) == 0) + { + return false; + } + + this->_value = std::move(typed_value); + this->_is_set = true; + return true; + } + + return false; + } + catch(const std::invalid_argument &) + { + return false; + } +} + +template +std::string EnumOption::help() const +{ + std::stringstream msg; + msg << "--" + this->name() + "={"; + + for(const auto &value : _allowed_values) + { + msg << value << ","; + } + + msg << "} - " << this->_help; + + return msg.str(); +} + +template +inline const T &EnumOption::value() const +{ + return this->_value; +} +} // namespace utils +} // namespace arm_compute +#endif /* ARM_COMPUTE_UTILS_ENUMOPTION */ diff --git a/utils/command_line/ListOption.h b/utils/command_line/ListOption.h new file mode 100644 index 0000000000..209a85d968 --- /dev/null +++ b/utils/command_line/ListOption.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_UTILS_LISTOPTION +#define ARM_COMPUTE_UTILS_LISTOPTION + +#include "Option.h" + +#include +#include +#include +#include +#include + +namespace arm_compute +{ +namespace utils +{ +/** Implementation of an option that accepts any number of values. */ +template +class ListOption : public Option +{ +public: + using Option::Option; + + /** Construct the option with the given default values. + * + * @param[in] name Name of the option. + * @param[in] default_values Default values. + */ + ListOption(std::string name, std::initializer_list &&default_values); + + bool parse(std::string value) override; + std::string help() const override; + + /** Get the list of option values. + * + * @return get the list of option values. + */ + const std::vector &value() const; + +private: + std::vector _values{}; +}; + +template +inline ListOption::ListOption(std::string name, std::initializer_list &&default_values) + : Option{ std::move(name), false, true }, _values{ std::forward>(default_values) } +{ +} + +template +bool ListOption::parse(std::string value) +{ + _is_set = true; + + try + { + std::stringstream stream{ value }; + std::string item; + + while(!std::getline(stream, item, ',').fail()) + { + std::stringstream item_stream(item); + T typed_value{}; + + item_stream >> typed_value; + + if(!item_stream.fail()) + { + _values.emplace_back(typed_value); + } + + _is_set = _is_set && !item_stream.fail(); + } + + return _is_set; + } + catch(const std::invalid_argument &) + { + return false; + } +} + +template +inline std::string ListOption::help() const +{ + return "--" + name() + "=VALUE[,VALUE[,...]] - " + _help; +} + +template +inline const std::vector &ListOption::value() const +{ + return _values; +} +} // namespace utils +} // namespace arm_compute +#endif /* ARM_COMPUTE_UTILS_LISTOPTION */ diff --git a/utils/command_line/Option.h b/utils/command_line/Option.h new file mode 100644 index 0000000000..b9469a5cc3 --- /dev/null +++ b/utils/command_line/Option.h @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_UTILS_OPTIONBASE +#define ARM_COMPUTE_UTILS_OPTIONBASE + +#include + +namespace arm_compute +{ +namespace utils +{ +/** Abstract base class for a command line option. */ +class Option +{ +public: + /** Constructor. + * + * @param[in] name Name of the option. + */ + Option(std::string name); + + /** Constructor. + * + * @param[in] name Name of the option. + * @param[in] is_required Is the option required? + * @param[in] is_set Has a value been assigned to the option? + */ + Option(std::string name, bool is_required, bool is_set); + + /** Default destructor. */ + virtual ~Option() = default; + + /** Parses the given string. + * + * @param[in] value String representation as passed on the command line. + * + * @return True if the value could be parsed by the specific subclass. + */ + virtual bool parse(std::string value) = 0; + + /** Help message for the option. + * + * @return String representing the help message for the specific subclass. + */ + virtual std::string help() const = 0; + + /** Name of the option. + * + * @return Name of the option. + */ + std::string name() const; + + /** Set whether the option is required. + * + * @param[in] is_required Pass true if the option is required. + */ + void set_required(bool is_required); + + /** Set the help message for the option. + * + * @param[in] help Option specific help message. + */ + void set_help(std::string help); + + /** Is the option required? + * + * @return True if the option is required. + */ + bool is_required() const; + + /** Has a value been assigned to the option? + * + * @return True if a value has been set. + */ + bool is_set() const; + +protected: + std::string _name; + bool _is_required{ false }; + bool _is_set{ false }; + std::string _help{}; +}; + +inline Option::Option(std::string name) + : _name{ std::move(name) } +{ +} + +inline Option::Option(std::string name, bool is_required, bool is_set) + : _name{ std::move(name) }, _is_required{ is_required }, _is_set{ is_set } +{ +} + +inline std::string Option::name() const +{ + return _name; +} + +inline void Option::set_required(bool is_required) +{ + _is_required = is_required; +} + +inline void Option::set_help(std::string help) +{ + _help = std::move(help); +} + +inline bool Option::is_required() const +{ + return _is_required; +} + +inline bool Option::is_set() const +{ + return _is_set; +} +} // namespace utils +} // namespace arm_compute +#endif /* ARM_COMPUTE_UTILS_OPTIONBASE */ diff --git a/utils/command_line/SimpleOption.h b/utils/command_line/SimpleOption.h new file mode 100644 index 0000000000..543759259a --- /dev/null +++ b/utils/command_line/SimpleOption.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_UTILS_SIMPLEOPTION +#define ARM_COMPUTE_UTILS_SIMPLEOPTION + +#include "Option.h" + +#include +#include +#include + +namespace arm_compute +{ +namespace utils +{ +/** Implementation of an option that accepts a single value. */ +template +class SimpleOption : public Option +{ +public: + using Option::Option; + + /** Construct the option with the given default value. + * + * @param[in] name Name of the option. + * @param[in] default_value Default value. + */ + SimpleOption(std::string name, T default_value); + + /** Parses the given string. + * + * @param[in] value String representation as passed on the command line. + * + * @return True if the value could be parsed by the specific subclass. + */ + bool parse(std::string value) override; + + /** Help message for the option. + * + * @return String representing the help message for the specific subclass. + */ + std::string help() const override; + + /** Get the option value. + * + * @return the option value. + */ + const T &value() const; + +protected: + T _value{}; +}; + +template +inline SimpleOption::SimpleOption(std::string name, T default_value) + : Option{ std::move(name), false, true }, _value{ std::move(default_value) } +{ +} + +template +bool SimpleOption::parse(std::string value) +{ + try + { + std::stringstream stream{ std::move(value) }; + stream >> _value; + _is_set = !stream.fail(); + return _is_set; + } + catch(const std::invalid_argument &) + { + return false; + } +} + +template <> +inline bool SimpleOption::parse(std::string value) +{ + _value = std::move(value); + _is_set = true; + return true; +} + +template +inline std::string SimpleOption::help() const +{ + return "--" + name() + "=VALUE - " + _help; +} + +template +inline const T &SimpleOption::value() const +{ + return _value; +} +} // namespace utils +} // namespace arm_compute +#endif /* ARM_COMPUTE_UTILS_SIMPLEOPTION */ diff --git a/utils/command_line/ToggleOption.h b/utils/command_line/ToggleOption.h new file mode 100644 index 0000000000..b1d2a32c64 --- /dev/null +++ b/utils/command_line/ToggleOption.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_UTILS_TOGGLEOPTION +#define ARM_COMPUTE_UTILS_TOGGLEOPTION + +#include "SimpleOption.h" + +#include + +namespace arm_compute +{ +namespace utils +{ +/** Implementation of an option that can be either true or false. */ +class ToggleOption : public SimpleOption +{ +public: + using SimpleOption::SimpleOption; + + /** Construct the option with the given default value. + * + * @param[in] name Name of the option. + * @param[in] default_value Default value. + */ + ToggleOption(std::string name, bool default_value); + + bool parse(std::string value) override; + std::string help() const override; +}; + +inline ToggleOption::ToggleOption(std::string name, bool default_value) + : SimpleOption +{ + std::move(name), default_value +} +{ +} + +inline bool ToggleOption::parse(std::string value) +{ + if(value == "true") + { + _value = true; + _is_set = true; + } + else if(value == "false") + { + _value = false; + _is_set = true; + } + + return _is_set; +} + +inline std::string ToggleOption::help() const +{ + return "--" + name() + ", --no-" + name() + " - " + _help; +} +} // namespace utils +} // namespace arm_compute +#endif /* ARM_COMPUTE_UTILS_TOGGLEOPTION */ -- cgit v1.2.1