aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/ModelAccuracyChecker.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/ModelAccuracyChecker.cpp')
-rw-r--r--src/armnnUtils/ModelAccuracyChecker.cpp62
1 files changed, 53 insertions, 9 deletions
diff --git a/src/armnnUtils/ModelAccuracyChecker.cpp b/src/armnnUtils/ModelAccuracyChecker.cpp
index bee5ca2365..81942dc2be 100644
--- a/src/armnnUtils/ModelAccuracyChecker.cpp
+++ b/src/armnnUtils/ModelAccuracyChecker.cpp
@@ -3,22 +3,27 @@
// SPDX-License-Identifier: MIT
//
-#include <vector>
-#include <map>
-#include <boost/log/trivial.hpp>
#include "ModelAccuracyChecker.hpp"
+#include <boost/filesystem.hpp>
+#include <boost/log/trivial.hpp>
+#include <map>
+#include <vector>
namespace armnnUtils
{
-armnnUtils::ModelAccuracyChecker::ModelAccuracyChecker(const std::map<std::string, int>& validationLabels)
- : m_GroundTruthLabelSet(validationLabels){}
+armnnUtils::ModelAccuracyChecker::ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabels,
+ const std::vector<LabelCategoryNames>& modelOutputLabels)
+ : m_GroundTruthLabelSet(validationLabels)
+ , m_ModelOutputLabels(modelOutputLabels)
+{}
float ModelAccuracyChecker::GetAccuracy(unsigned int k)
{
- if(k > 10) {
- BOOST_LOG_TRIVIAL(info) << "Accuracy Tool only supports a maximum of Top 10 Accuracy. "
- "Printing Top 10 Accuracy result!";
+ if (k > 10)
+ {
+ BOOST_LOG_TRIVIAL(warning) << "Accuracy Tool only supports a maximum of Top 10 Accuracy. "
+ "Printing Top 10 Accuracy result!";
k = 10;
}
unsigned int total = 0;
@@ -28,4 +33,43 @@ float ModelAccuracyChecker::GetAccuracy(unsigned int k)
}
return static_cast<float>(total * 100) / static_cast<float>(m_ImagesProcessed);
}
-} \ No newline at end of file
+
+// Split a string into tokens by a delimiter
+std::vector<std::string>
+ SplitBy(const std::string& originalString, const std::string& delimiter, bool includeEmptyToken)
+{
+ std::vector<std::string> tokens;
+ size_t cur = 0;
+ size_t next = 0;
+ while ((next = originalString.find(delimiter, cur)) != std::string::npos)
+ {
+ // Skip empty tokens, unless explicitly stated to include them.
+ if (next - cur > 0 || includeEmptyToken)
+ {
+ tokens.push_back(originalString.substr(cur, next - cur));
+ }
+ cur = next + delimiter.size();
+ }
+ // Get the remaining token
+ // Skip empty tokens, unless explicitly stated to include them.
+ if (originalString.size() - cur > 0 || includeEmptyToken)
+ {
+ tokens.push_back(originalString.substr(cur, originalString.size() - cur));
+ }
+ return tokens;
+}
+
+// Remove any preceding and trailing character specified in the characterSet.
+std::string Strip(const std::string& originalString, const std::string& characterSet)
+{
+ BOOST_ASSERT(!characterSet.empty());
+ const std::size_t firstFound = originalString.find_first_not_of(characterSet);
+ const std::size_t lastFound = originalString.find_last_not_of(characterSet);
+ // Return empty if the originalString is empty or the originalString contains only to-be-striped characters
+ if (firstFound == std::string::npos || lastFound == std::string::npos)
+ {
+ return "";
+ }
+ return originalString.substr(firstFound, lastFound + 1 - firstFound);
+}
+} // namespace armnnUtils \ No newline at end of file