diff options
author | SiCong Li <sicong.li@arm.com> | 2019-06-24 16:03:33 +0100 |
---|---|---|
committer | sicong.li <sicong.li@arm.com> | 2019-07-15 11:05:36 +0000 |
commit | 898a324d4e5c09e53bbc5925d70577b2f45f753d (patch) | |
tree | 6bc8e8629948959ef3c7c8f1d33ac8abb2d6f6c8 /src/armnnUtils/ModelAccuracyChecker.cpp | |
parent | 454d1f5d5ad2b63ba21cc1ed4a59ac9710991f55 (diff) | |
download | armnn-898a324d4e5c09e53bbc5925d70577b2f45f753d.tar.gz |
MLCE-103 Add necessary enhancements to ModelAccuracyTool
* Evaluate model accuracy using category names instead of numerical
labels.
* Add blacklist support
* Add range selection support
Signed-off-by: SiCong Li <sicong.li@arm.com>
Change-Id: I7b1d2d298cfcaa56a27a028147169404b73580bb
Diffstat (limited to 'src/armnnUtils/ModelAccuracyChecker.cpp')
-rw-r--r-- | src/armnnUtils/ModelAccuracyChecker.cpp | 62 |
1 files changed, 53 insertions, 9 deletions
diff --git a/src/armnnUtils/ModelAccuracyChecker.cpp b/src/armnnUtils/ModelAccuracyChecker.cpp index bee5ca2365..81942dc2be 100644 --- a/src/armnnUtils/ModelAccuracyChecker.cpp +++ b/src/armnnUtils/ModelAccuracyChecker.cpp @@ -3,22 +3,27 @@ // SPDX-License-Identifier: MIT // -#include <vector> -#include <map> -#include <boost/log/trivial.hpp> #include "ModelAccuracyChecker.hpp" +#include <boost/filesystem.hpp> +#include <boost/log/trivial.hpp> +#include <map> +#include <vector> namespace armnnUtils { -armnnUtils::ModelAccuracyChecker::ModelAccuracyChecker(const std::map<std::string, int>& validationLabels) - : m_GroundTruthLabelSet(validationLabels){} +armnnUtils::ModelAccuracyChecker::ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabels, + const std::vector<LabelCategoryNames>& modelOutputLabels) + : m_GroundTruthLabelSet(validationLabels) + , m_ModelOutputLabels(modelOutputLabels) +{} float ModelAccuracyChecker::GetAccuracy(unsigned int k) { - if(k > 10) { - BOOST_LOG_TRIVIAL(info) << "Accuracy Tool only supports a maximum of Top 10 Accuracy. " - "Printing Top 10 Accuracy result!"; + if (k > 10) + { + BOOST_LOG_TRIVIAL(warning) << "Accuracy Tool only supports a maximum of Top 10 Accuracy. " + "Printing Top 10 Accuracy result!"; k = 10; } unsigned int total = 0; @@ -28,4 +33,43 @@ float ModelAccuracyChecker::GetAccuracy(unsigned int k) } return static_cast<float>(total * 100) / static_cast<float>(m_ImagesProcessed); } -}
\ No newline at end of file + +// Split a string into tokens by a delimiter +std::vector<std::string> + SplitBy(const std::string& originalString, const std::string& delimiter, bool includeEmptyToken) +{ + std::vector<std::string> tokens; + size_t cur = 0; + size_t next = 0; + while ((next = originalString.find(delimiter, cur)) != std::string::npos) + { + // Skip empty tokens, unless explicitly stated to include them. + if (next - cur > 0 || includeEmptyToken) + { + tokens.push_back(originalString.substr(cur, next - cur)); + } + cur = next + delimiter.size(); + } + // Get the remaining token + // Skip empty tokens, unless explicitly stated to include them. + if (originalString.size() - cur > 0 || includeEmptyToken) + { + tokens.push_back(originalString.substr(cur, originalString.size() - cur)); + } + return tokens; +} + +// Remove any preceding and trailing character specified in the characterSet. +std::string Strip(const std::string& originalString, const std::string& characterSet) +{ + BOOST_ASSERT(!characterSet.empty()); + const std::size_t firstFound = originalString.find_first_not_of(characterSet); + const std::size_t lastFound = originalString.find_last_not_of(characterSet); + // Return empty if the originalString is empty or the originalString contains only to-be-striped characters + if (firstFound == std::string::npos || lastFound == std::string::npos) + { + return ""; + } + return originalString.substr(firstFound, lastFound + 1 - firstFound); +} +} // namespace armnnUtils
\ No newline at end of file |