aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/ModelAccuracyCheckerTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/ModelAccuracyCheckerTest.cpp')
-rw-r--r--src/armnn/test/ModelAccuracyCheckerTest.cpp58
1 files changed, 38 insertions, 20 deletions
diff --git a/src/armnn/test/ModelAccuracyCheckerTest.cpp b/src/armnn/test/ModelAccuracyCheckerTest.cpp
index f3a6c9d81d..aa1fba212c 100644
--- a/src/armnn/test/ModelAccuracyCheckerTest.cpp
+++ b/src/armnn/test/ModelAccuracyCheckerTest.cpp
@@ -7,32 +7,50 @@
#include <boost/algorithm/string.hpp>
#include <boost/test/unit_test.hpp>
-#include <iostream>
-#include <string>
-#include <boost/log/core/core.hpp>
#include <boost/filesystem.hpp>
+#include <boost/log/core/core.hpp>
#include <boost/optional.hpp>
#include <boost/variant.hpp>
+#include <iostream>
+#include <string>
using namespace armnnUtils;
-struct TestHelper {
- const std::map<std::string, int> GetValidationLabelSet()
+struct TestHelper
+{
+ const std::map<std::string, std::string> GetValidationLabelSet()
{
- std::map<std::string, int> validationLabelSet;
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000001", 2));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000002", 9));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000003", 1));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000004", 6));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000005", 5));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000006", 0));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000007", 8));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000008", 4));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000009", 3));
- validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000009", 7));
+ std::map<std::string, std::string> validationLabelSet;
+ validationLabelSet.insert(std::make_pair("val_01.JPEG", "goldfinch"));
+ validationLabelSet.insert(std::make_pair("val_02.JPEG", "magpie"));
+ validationLabelSet.insert(std::make_pair("val_03.JPEG", "brambling"));
+ validationLabelSet.insert(std::make_pair("val_04.JPEG", "robin"));
+ validationLabelSet.insert(std::make_pair("val_05.JPEG", "indigo bird"));
+ validationLabelSet.insert(std::make_pair("val_06.JPEG", "ostrich"));
+ validationLabelSet.insert(std::make_pair("val_07.JPEG", "jay"));
+ validationLabelSet.insert(std::make_pair("val_08.JPEG", "snowbird"));
+ validationLabelSet.insert(std::make_pair("val_09.JPEG", "house finch"));
+ validationLabelSet.insert(std::make_pair("val_09.JPEG", "bulbul"));
return validationLabelSet;
}
+ const std::vector<armnnUtils::LabelCategoryNames> GetModelOutputLabels()
+ {
+ const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels =
+ {
+ {"ostrich", "Struthio camelus"},
+ {"brambling", "Fringilla montifringilla"},
+ {"goldfinch", "Carduelis carduelis"},
+ {"house finch", "linnet", "Carpodacus mexicanus"},
+ {"junco", "snowbird"},
+ {"indigo bunting", "indigo finch", "indigo bird", "Passerina cyanea"},
+ {"robin", "American robin", "Turdus migratorius"},
+ {"bulbul"},
+ {"jay"},
+ {"magpie"}
+ };
+ return modelOutputLabels;
+ }
};
BOOST_AUTO_TEST_SUITE(ModelAccuracyCheckerTest)
@@ -41,7 +59,7 @@ using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vec
BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper)
{
- ModelAccuracyChecker checker(GetValidationLabelSet());
+ ModelAccuracyChecker checker(GetValidationLabelSet(), GetModelOutputLabels());
// Add image 1 and check accuracy
std::vector<float> inferenceOutputVector1 = {0.05f, 0.10f, 0.70f, 0.15f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
@@ -49,7 +67,7 @@ BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper)
std::vector<TContainer> outputTensor1;
outputTensor1.push_back(inference1Container);
- std::string imageName = "ILSVRC2012_val_00000001.JPEG";
+ std::string imageName = "val_01.JPEG";
checker.AddImageResult<TContainer>(imageName, outputTensor1);
// Top 1 Accuracy
@@ -62,7 +80,7 @@ BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper)
std::vector<TContainer> outputTensor2;
outputTensor2.push_back(inference2Container);
- imageName = "ILSVRC2012_val_00000002.JPEG";
+ imageName = "val_02.JPEG";
checker.AddImageResult<TContainer>(imageName, outputTensor2);
// Top 1 Accuracy
@@ -79,7 +97,7 @@ BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper)
std::vector<TContainer> outputTensor3;
outputTensor3.push_back(inference3Container);
- imageName = "ILSVRC2012_val_00000003.JPEG";
+ imageName = "val_03.JPEG";
checker.AddImageResult<TContainer>(imageName, outputTensor3);
// Top 1 Accuracy