diff options
author | SiCong Li <sicong.li@arm.com> | 2019-06-24 16:03:33 +0100 |
---|---|---|
committer | sicong.li <sicong.li@arm.com> | 2019-07-15 11:05:36 +0000 |
commit | 898a324d4e5c09e53bbc5925d70577b2f45f753d (patch) | |
tree | 6bc8e8629948959ef3c7c8f1d33ac8abb2d6f6c8 /src/armnn/test/ModelAccuracyCheckerTest.cpp | |
parent | 454d1f5d5ad2b63ba21cc1ed4a59ac9710991f55 (diff) | |
download | armnn-898a324d4e5c09e53bbc5925d70577b2f45f753d.tar.gz |
MLCE-103 Add necessary enhancements to ModelAccuracyTool
* Evaluate model accuracy using category names instead of numerical
labels.
* Add blacklist support
* Add range selection support
Signed-off-by: SiCong Li <sicong.li@arm.com>
Change-Id: I7b1d2d298cfcaa56a27a028147169404b73580bb
Diffstat (limited to 'src/armnn/test/ModelAccuracyCheckerTest.cpp')
-rw-r--r-- | src/armnn/test/ModelAccuracyCheckerTest.cpp | 58 |
1 files changed, 38 insertions, 20 deletions
diff --git a/src/armnn/test/ModelAccuracyCheckerTest.cpp b/src/armnn/test/ModelAccuracyCheckerTest.cpp index f3a6c9d81d..aa1fba212c 100644 --- a/src/armnn/test/ModelAccuracyCheckerTest.cpp +++ b/src/armnn/test/ModelAccuracyCheckerTest.cpp @@ -7,32 +7,50 @@ #include <boost/algorithm/string.hpp> #include <boost/test/unit_test.hpp> -#include <iostream> -#include <string> -#include <boost/log/core/core.hpp> #include <boost/filesystem.hpp> +#include <boost/log/core/core.hpp> #include <boost/optional.hpp> #include <boost/variant.hpp> +#include <iostream> +#include <string> using namespace armnnUtils; -struct TestHelper { - const std::map<std::string, int> GetValidationLabelSet() +struct TestHelper +{ + const std::map<std::string, std::string> GetValidationLabelSet() { - std::map<std::string, int> validationLabelSet; - validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000001", 2)); - validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000002", 9)); - validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000003", 1)); - validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000004", 6)); - validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000005", 5)); - validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000006", 0)); - validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000007", 8)); - validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000008", 4)); - validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000009", 3)); - validationLabelSet.insert( std::make_pair("ILSVRC2012_val_00000009", 7)); + std::map<std::string, std::string> validationLabelSet; + validationLabelSet.insert(std::make_pair("val_01.JPEG", "goldfinch")); + validationLabelSet.insert(std::make_pair("val_02.JPEG", "magpie")); + validationLabelSet.insert(std::make_pair("val_03.JPEG", "brambling")); + validationLabelSet.insert(std::make_pair("val_04.JPEG", "robin")); + validationLabelSet.insert(std::make_pair("val_05.JPEG", "indigo bird")); + validationLabelSet.insert(std::make_pair("val_06.JPEG", "ostrich")); + validationLabelSet.insert(std::make_pair("val_07.JPEG", "jay")); + validationLabelSet.insert(std::make_pair("val_08.JPEG", "snowbird")); + validationLabelSet.insert(std::make_pair("val_09.JPEG", "house finch")); + validationLabelSet.insert(std::make_pair("val_09.JPEG", "bulbul")); return validationLabelSet; } + const std::vector<armnnUtils::LabelCategoryNames> GetModelOutputLabels() + { + const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels = + { + {"ostrich", "Struthio camelus"}, + {"brambling", "Fringilla montifringilla"}, + {"goldfinch", "Carduelis carduelis"}, + {"house finch", "linnet", "Carpodacus mexicanus"}, + {"junco", "snowbird"}, + {"indigo bunting", "indigo finch", "indigo bird", "Passerina cyanea"}, + {"robin", "American robin", "Turdus migratorius"}, + {"bulbul"}, + {"jay"}, + {"magpie"} + }; + return modelOutputLabels; + } }; BOOST_AUTO_TEST_SUITE(ModelAccuracyCheckerTest) @@ -41,7 +59,7 @@ using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vec BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper) { - ModelAccuracyChecker checker(GetValidationLabelSet()); + ModelAccuracyChecker checker(GetValidationLabelSet(), GetModelOutputLabels()); // Add image 1 and check accuracy std::vector<float> inferenceOutputVector1 = {0.05f, 0.10f, 0.70f, 0.15f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; @@ -49,7 +67,7 @@ BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper) std::vector<TContainer> outputTensor1; outputTensor1.push_back(inference1Container); - std::string imageName = "ILSVRC2012_val_00000001.JPEG"; + std::string imageName = "val_01.JPEG"; checker.AddImageResult<TContainer>(imageName, outputTensor1); // Top 1 Accuracy @@ -62,7 +80,7 @@ BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper) std::vector<TContainer> outputTensor2; outputTensor2.push_back(inference2Container); - imageName = "ILSVRC2012_val_00000002.JPEG"; + imageName = "val_02.JPEG"; checker.AddImageResult<TContainer>(imageName, outputTensor2); // Top 1 Accuracy @@ -79,7 +97,7 @@ BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper) std::vector<TContainer> outputTensor3; outputTensor3.push_back(inference3Container); - imageName = "ILSVRC2012_val_00000003.JPEG"; + imageName = "val_03.JPEG"; checker.AddImageResult<TContainer>(imageName, outputTensor3); // Top 1 Accuracy |