aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSiCong Li <sicong.li@arm.com>2019-06-24 16:03:33 +0100
committersicong.li <sicong.li@arm.com>2019-07-15 11:05:36 +0000
commit898a324d4e5c09e53bbc5925d70577b2f45f753d (patch)
tree6bc8e8629948959ef3c7c8f1d33ac8abb2d6f6c8 /src
parent454d1f5d5ad2b63ba21cc1ed4a59ac9710991f55 (diff)
downloadarmnn-898a324d4e5c09e53bbc5925d70577b2f45f753d.tar.gz
MLCE-103 Add necessary enhancements to ModelAccuracyTool
* Evaluate model accuracy using category names instead of numerical labels. * Add blacklist support * Add range selection support Signed-off-by: SiCong Li <sicong.li@arm.com> Change-Id: I7b1d2d298cfcaa56a27a028147169404b73580bb
Diffstat (limited to 'src')
-rw-r--r--src/armnn/test/ModelAccuracyCheckerTest.cpp58
-rw-r--r--src/armnnUtils/ModelAccuracyChecker.cpp62
-rw-r--r--src/armnnUtils/ModelAccuracyChecker.hpp93
3 files changed, 153 insertions, 60 deletions
diff --git a/src/armnn/test/ModelAccuracyCheckerTest.cpp b/src/armnn/test/ModelAccuracyCheckerTest.cpp
index f3a6c9d81d..aa1fba212c 100644
--- a/src/armnn/test/ModelAccuracyCheckerTest.cpp
+++ b/src/armnn/test/ModelAccuracyCheckerTest.cpp
@@ -7,32 +7,50 @@
#include <boost/algorithm/string.hpp>
#include <boost/test/unit_test.hpp>
-#include <iostream>
-#include <string>
-#include <boost/log/core/core.hpp>
#include <boost/filesystem.hpp>
+#include <boost/log/core/core.hpp>
#include <boost/optional.hpp>
#include <boost/variant.hpp>
+#include <iostream>
+#include <string>
using namespace armnnUtils;
-struct TestHelper {
- const std::map<std::string, int> GetValidationLabelSet()
+struct TestHelper
+{
+ const std::map<std::string, std::string> GetValidationLabelSet()
{
- std::map<std::string, int> validationLabelSet;
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000001", 2));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000002", 9));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000003", 1));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000004", 6));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000005", 5));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000006", 0));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000007", 8));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000008", 4));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000009", 3));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000009", 7));
+ std::map<std::string, std::string> validationLabelSet;
+ validationLabelSet.insert(std::make_pair("val_01.JPEG", "goldfinch"));
+ validationLabelSet.insert(std::make_pair("val_02.JPEG", "magpie"));
+ validationLabelSet.insert(std::make_pair("val_03.JPEG", "brambling"));
+ validationLabelSet.insert(std::make_pair("val_04.JPEG", "robin"));
+ validationLabelSet.insert(std::make_pair("val_05.JPEG", "indigo bird"));
+ validationLabelSet.insert(std::make_pair("val_06.JPEG", "ostrich"));
+ validationLabelSet.insert(std::make_pair("val_07.JPEG", "jay"));
+ validationLabelSet.insert(std::make_pair("val_08.JPEG", "snowbird"));
+ validationLabelSet.insert(std::make_pair("val_09.JPEG", "house finch"));
+ validationLabelSet.insert(std::make_pair("val_09.JPEG", "bulbul"));
return validationLabelSet;
}
+ const std::vector<armnnUtils::LabelCategoryNames> GetModelOutputLabels()
+ {
+ const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels =
+ {
+ {"ostrich", "Struthio camelus"},
+ {"brambling", "Fringilla montifringilla"},
+ {"goldfinch", "Carduelis carduelis"},
+ {"house finch", "linnet", "Carpodacus mexicanus"},
+ {"junco", "snowbird"},
+ {"indigo bunting", "indigo finch", "indigo bird", "Passerina cyanea"},
+ {"robin", "American robin", "Turdus migratorius"},
+ {"bulbul"},
+ {"jay"},
+ {"magpie"}
+ };
+ return modelOutputLabels;
+ }
};
BOOST_AUTO_TEST_SUITE(ModelAccuracyCheckerTest)
@@ -41,7 +59,7 @@ using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vec
BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper)
{
- ModelAccuracyChecker checker(GetValidationLabelSet());
+ ModelAccuracyChecker checker(GetValidationLabelSet(), GetModelOutputLabels());
// Add image 1 and check accuracy
std::vector<float> inferenceOutputVector1 = {0.05f, 0.10f, 0.70f, 0.15f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
@@ -49,7 +67,7 @@ BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper)
std::vector<TContainer> outputTensor1;
outputTensor1.push_back(inference1Container);
- std::string imageName = "ILSVRC2012_val_00000001.JPEG";
+ std::string imageName = "val_01.JPEG";
checker.AddImageResult<TContainer>(imageName, outputTensor1);
// Top 1 Accuracy
@@ -62,7 +80,7 @@ BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper)
std::vector<TContainer> outputTensor2;
outputTensor2.push_back(inference2Container);
- imageName = "ILSVRC2012_val_00000002.JPEG";
+ imageName = "val_02.JPEG";
checker.AddImageResult<TContainer>(imageName, outputTensor2);
// Top 1 Accuracy
@@ -79,7 +97,7 @@ BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper)
std::vector<TContainer> outputTensor3;
outputTensor3.push_back(inference3Container);
- imageName = "ILSVRC2012_val_00000003.JPEG";
+ imageName = "val_03.JPEG";
checker.AddImageResult<TContainer>(imageName, outputTensor3);
// Top 1 Accuracy
diff --git a/src/armnnUtils/ModelAccuracyChecker.cpp b/src/armnnUtils/ModelAccuracyChecker.cpp
index bee5ca2365..81942dc2be 100644
--- a/src/armnnUtils/ModelAccuracyChecker.cpp
+++ b/src/armnnUtils/ModelAccuracyChecker.cpp
@@ -3,22 +3,27 @@
// SPDX-License-Identifier: MIT
//
-#include <vector>
-#include <map>
-#include <boost/log/trivial.hpp>
#include "ModelAccuracyChecker.hpp"
+#include <boost/filesystem.hpp>
+#include <boost/log/trivial.hpp>
+#include <map>
+#include <vector>
namespace armnnUtils
{
-armnnUtils::ModelAccuracyChecker::ModelAccuracyChecker(const std::map<std::string, int>& validationLabels)
- : m_GroundTruthLabelSet(validationLabels){}
+armnnUtils::ModelAccuracyChecker::ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabels,
+ const std::vector<LabelCategoryNames>& modelOutputLabels)
+ : m_GroundTruthLabelSet(validationLabels)
+ , m_ModelOutputLabels(modelOutputLabels)
+{}
float ModelAccuracyChecker::GetAccuracy(unsigned int k)
{
- if(k > 10) {
- BOOST_LOG_TRIVIAL(info) << "Accuracy Tool only supports a maximum of Top 10 Accuracy. "
- "Printing Top 10 Accuracy result!";
+ if (k > 10)
+ {
+ BOOST_LOG_TRIVIAL(warning) << "Accuracy Tool only supports a maximum of Top 10 Accuracy. "
+ "Printing Top 10 Accuracy result!";
k = 10;
}
unsigned int total = 0;
@@ -28,4 +33,43 @@ float ModelAccuracyChecker::GetAccuracy(unsigned int k)
}
return static_cast<float>(total * 100) / static_cast<float>(m_ImagesProcessed);
}
-} \ No newline at end of file
+
+// Split a string into tokens by a delimiter
+std::vector<std::string>
+ SplitBy(const std::string& originalString, const std::string& delimiter, bool includeEmptyToken)
+{
+ std::vector<std::string> tokens;
+ size_t cur = 0;
+ size_t next = 0;
+ while ((next = originalString.find(delimiter, cur)) != std::string::npos)
+ {
+ // Skip empty tokens, unless explicitly stated to include them.
+ if (next - cur > 0 || includeEmptyToken)
+ {
+ tokens.push_back(originalString.substr(cur, next - cur));
+ }
+ cur = next + delimiter.size();
+ }
+ // Get the remaining token
+ // Skip empty tokens, unless explicitly stated to include them.
+ if (originalString.size() - cur > 0 || includeEmptyToken)
+ {
+ tokens.push_back(originalString.substr(cur, originalString.size() - cur));
+ }
+ return tokens;
+}
+
+// Remove any preceding and trailing character specified in the characterSet.
+std::string Strip(const std::string& originalString, const std::string& characterSet)
+{
+ BOOST_ASSERT(!characterSet.empty());
+ const std::size_t firstFound = originalString.find_first_not_of(characterSet);
+ const std::size_t lastFound = originalString.find_last_not_of(characterSet);
+ // Return empty if the originalString is empty or the originalString contains only to-be-striped characters
+ if (firstFound == std::string::npos || lastFound == std::string::npos)
+ {
+ return "";
+ }
+ return originalString.substr(firstFound, lastFound + 1 - firstFound);
+}
+} // namespace armnnUtils \ No newline at end of file
diff --git a/src/armnnUtils/ModelAccuracyChecker.hpp b/src/armnnUtils/ModelAccuracyChecker.hpp
index cdd2af0ac5..c4dd4f1b05 100644
--- a/src/armnnUtils/ModelAccuracyChecker.hpp
+++ b/src/armnnUtils/ModelAccuracyChecker.hpp
@@ -5,39 +5,81 @@
#pragma once
+#include <algorithm>
+#include <armnn/Types.hpp>
+#include <boost/assert.hpp>
+#include <boost/variant/apply_visitor.hpp>
#include <cstddef>
-#include <string>
+#include <functional>
+#include <iostream>
#include <map>
+#include <string>
#include <vector>
-#include <boost/variant/apply_visitor.hpp>
-#include <iostream>
-#include <armnn/Types.hpp>
-#include <functional>
-#include <algorithm>
namespace armnnUtils
{
using namespace armnn;
+// Category names associated with a label
+using LabelCategoryNames = std::vector<std::string>;
+
+/** Split a string into tokens by a delimiter
+ *
+ * @param[in] originalString Original string to be split
+ * @param[in] delimiter Delimiter used to split \p originalString
+ * @param[in] includeEmptyToekn If true, include empty tokens in the result
+ * @return A vector of tokens split from \p originalString by \delimiter
+ */
+std::vector<std::string>
+ SplitBy(const std::string& originalString, const std::string& delimiter = " ", bool includeEmptyToken = false);
+
+/** Remove any preceding and trailing character specified in the characterSet.
+ *
+ * @param[in] originalString Original string to be stripped
+ * @param[in] characterSet Set of characters to be stripped from \p originalString
+ * @return A string stripped of all characters specified in \p characterSet from \p originalString
+ */
+std::string Strip(const std::string& originalString, const std::string& characterSet = " ");
+
class ModelAccuracyChecker
{
public:
- ModelAccuracyChecker(const std::map<std::string, int>& validationLabelSet);
-
+ /** Constructor for a model top k accuracy checker
+ *
+ * @param[in] validationLabelSet Mapping from names of images to be validated, to category names of their
+ corresponding ground-truth labels.
+ * @param[in] modelOutputLabels Mapping from output nodes to the category names of their corresponding labels
+ Note that an output node can have multiple category names.
+ */
+ ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabelSet,
+ const std::vector<LabelCategoryNames>& modelOutputLabels);
+
+ /** Get Top K accuracy
+ *
+ * @param[in] k The number of top predictions to use for validating the ground-truth label. For example, if \p k is
+ 3, then a prediction is considered correct as long as the ground-truth appears in the top 3
+ predictions.
+ * @return The accuracy, according to the top \p k th predictions.
+ */
float GetAccuracy(unsigned int k);
- template<typename TContainer>
+ /** Record the prediction result of an image
+ *
+ * @param[in] imageName Name of the image.
+ * @param[in] outputTensor Output tensor of the network running \p imageName.
+ */
+ template <typename TContainer>
void AddImageResult(const std::string& imageName, std::vector<TContainer> outputTensor)
{
// Increment the total number of images processed
++m_ImagesProcessed;
std::map<int, float> confidenceMap;
- auto & output = outputTensor[0];
+ auto& output = outputTensor[0];
// Create a map of all predictions
- boost::apply_visitor([&](auto && value)
+ boost::apply_visitor([&confidenceMap](auto && value)
{
int index = 0;
for (const auto & o : value)
@@ -64,8 +106,7 @@ public:
std::set<std::pair<int, float>, Comparator> setOfPredictions(
confidenceMap.begin(), confidenceMap.end(), compFunctor);
- std::string trimmedName = GetTrimmedImageName(imageName);
- int value = m_GroundTruthLabelSet.find(trimmedName)->second;
+ const std::string correctLabel = m_GroundTruthLabelSet.at(imageName);
unsigned int index = 1;
for (std::pair<int, float> element : setOfPredictions)
@@ -74,7 +115,10 @@ public:
{
break;
}
- if (element.first == value)
+ // Check if the ground truth label value is included in the topi prediction.
+ // Note that a prediction can have multiple prediction labels.
+ const LabelCategoryNames predictionLabels = m_ModelOutputLabels[static_cast<size_t>(element.first)];
+ if (std::find(predictionLabels.begin(), predictionLabels.end(), correctLabel) != predictionLabels.end())
{
++m_TopK[index];
break;
@@ -83,24 +127,11 @@ public:
}
}
- std::string GetTrimmedImageName(const std::string& imageName) const
- {
- std::string trimmedName;
- size_t lastindex = imageName.find_last_of(".");
- if(lastindex != std::string::npos)
- {
- trimmedName = imageName.substr(0, lastindex);
- } else
- {
- trimmedName = imageName;
- }
- return trimmedName;
- }
-
private:
- const std::map<std::string, int> m_GroundTruthLabelSet;
- std::vector<unsigned int> m_TopK = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
- unsigned int m_ImagesProcessed = 0;
+ const std::map<std::string, std::string> m_GroundTruthLabelSet;
+ const std::vector<LabelCategoryNames> m_ModelOutputLabels;
+ std::vector<unsigned int> m_TopK = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
+ unsigned int m_ImagesProcessed = 0;
};
} //namespace armnnUtils