ArmNN
 20.05
ModelAccuracyCheckerTest.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
6 
7 #include <boost/test/unit_test.hpp>
8 
9 #include <boost/filesystem.hpp>
10 #include <boost/optional.hpp>
11 #include <boost/variant.hpp>
12 #include <iostream>
13 #include <string>
14 
15 using namespace armnnUtils;
16 
17 namespace {
18 struct TestHelper
19 {
20  const std::map<std::string, std::string> GetValidationLabelSet()
21  {
22  std::map<std::string, std::string> validationLabelSet;
23  validationLabelSet.insert(std::make_pair("val_01.JPEG", "goldfinch"));
24  validationLabelSet.insert(std::make_pair("val_02.JPEG", "magpie"));
25  validationLabelSet.insert(std::make_pair("val_03.JPEG", "brambling"));
26  validationLabelSet.insert(std::make_pair("val_04.JPEG", "robin"));
27  validationLabelSet.insert(std::make_pair("val_05.JPEG", "indigo bird"));
28  validationLabelSet.insert(std::make_pair("val_06.JPEG", "ostrich"));
29  validationLabelSet.insert(std::make_pair("val_07.JPEG", "jay"));
30  validationLabelSet.insert(std::make_pair("val_08.JPEG", "snowbird"));
31  validationLabelSet.insert(std::make_pair("val_09.JPEG", "house finch"));
32  validationLabelSet.insert(std::make_pair("val_09.JPEG", "bulbul"));
33 
34  return validationLabelSet;
35  }
36  const std::vector<armnnUtils::LabelCategoryNames> GetModelOutputLabels()
37  {
38  const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels =
39  {
40  {"ostrich", "Struthio camelus"},
41  {"brambling", "Fringilla montifringilla"},
42  {"goldfinch", "Carduelis carduelis"},
43  {"house finch", "linnet", "Carpodacus mexicanus"},
44  {"junco", "snowbird"},
45  {"indigo bunting", "indigo finch", "indigo bird", "Passerina cyanea"},
46  {"robin", "American robin", "Turdus migratorius"},
47  {"bulbul"},
48  {"jay"},
49  {"magpie"}
50  };
51  return modelOutputLabels;
52  }
53 };
54 }
55 
56 BOOST_AUTO_TEST_SUITE(ModelAccuracyCheckerTest)
57 
58 using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
59 
60 BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper)
61 {
62  ModelAccuracyChecker checker(GetValidationLabelSet(), GetModelOutputLabels());
63 
64  // Add image 1 and check accuracy
65  std::vector<float> inferenceOutputVector1 = {0.05f, 0.10f, 0.70f, 0.15f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
66  TContainer inference1Container(inferenceOutputVector1);
67  std::vector<TContainer> outputTensor1;
68  outputTensor1.push_back(inference1Container);
69 
70  std::string imageName = "val_01.JPEG";
71  checker.AddImageResult<TContainer>(imageName, outputTensor1);
72 
73  // Top 1 Accuracy
74  float totalAccuracy = checker.GetAccuracy(1);
75  BOOST_CHECK(totalAccuracy == 100.0f);
76 
77  // Add image 2 and check accuracy
78  std::vector<float> inferenceOutputVector2 = {0.10f, 0.0f, 0.0f, 0.0f, 0.05f, 0.70f, 0.0f, 0.0f, 0.0f, 0.15f};
79  TContainer inference2Container(inferenceOutputVector2);
80  std::vector<TContainer> outputTensor2;
81  outputTensor2.push_back(inference2Container);
82 
83  imageName = "val_02.JPEG";
84  checker.AddImageResult<TContainer>(imageName, outputTensor2);
85 
86  // Top 1 Accuracy
87  totalAccuracy = checker.GetAccuracy(1);
88  BOOST_CHECK(totalAccuracy == 50.0f);
89 
90  // Top 2 Accuracy
91  totalAccuracy = checker.GetAccuracy(2);
92  BOOST_CHECK(totalAccuracy == 100.0f);
93 
94  // Add image 3 and check accuracy
95  std::vector<float> inferenceOutputVector3 = {0.0f, 0.10f, 0.0f, 0.0f, 0.05f, 0.70f, 0.0f, 0.0f, 0.0f, 0.15f};
96  TContainer inference3Container(inferenceOutputVector3);
97  std::vector<TContainer> outputTensor3;
98  outputTensor3.push_back(inference3Container);
99 
100  imageName = "val_03.JPEG";
101  checker.AddImageResult<TContainer>(imageName, outputTensor3);
102 
103  // Top 1 Accuracy
104  totalAccuracy = checker.GetAccuracy(1);
105  BOOST_CHECK(totalAccuracy == 33.3333321f);
106 
107  // Top 2 Accuracy
108  totalAccuracy = checker.GetAccuracy(2);
109  BOOST_CHECK(totalAccuracy == 66.6666641f);
110 
111  // Top 3 Accuracy
112  totalAccuracy = checker.GetAccuracy(3);
113  BOOST_CHECK(totalAccuracy == 100.0f);
114 }
115 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper)
float GetAccuracy(unsigned int k)
Get Top K accuracy.
BOOST_CHECK(profilingService.GetCurrentState()==ProfilingState::WaitingForAck)
void AddImageResult(const std::string &imageName, std::vector< TContainer > outputTensor)
Record the prediction result of an image.
boost::variant< std::vector< float >, std::vector< int >, std::vector< unsigned char > > TContainer
BOOST_AUTO_TEST_SUITE_END()