ArmNN
 21.02
QuantizationDataSetTests.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
7 
8 #include "../QuantizationDataSet.hpp"
9 
10 #include <armnn/Optional.hpp>
11 #include <Filesystem.hpp>
12 #include <iostream>
13 #include <fstream>
14 #include <vector>
15 #include <map>
16 
17 
18 using namespace armnnQuantizer;
19 
20 struct CsvTestHelper {
21 
22  CsvTestHelper()
23  {
24  BOOST_TEST_MESSAGE("setup fixture");
25  }
26 
27  ~CsvTestHelper()
28  {
29  BOOST_TEST_MESSAGE("teardown fixture");
30  TearDown();
31  }
32 
33  std::string CreateTempCsvFile(std::map<int, std::vector<float>> csvData)
34  {
35  fs::path fileDir = fs::temp_directory_path();
36  fs::path p = armnnUtils::Filesystem::NamedTempFile("Armnn-QuantizationCreateTempCsvFileTest-TempFile.csv");
37 
38  fs::path tensorInput1{fileDir / "input_0_0.raw"};
39  fs::path tensorInput2{fileDir / "input_1_0.raw"};
40  fs::path tensorInput3{fileDir / "input_2_0.raw"};
41 
42  try
43  {
44  std::ofstream ofs{p};
45 
46  std::ofstream ofs1{tensorInput1};
47  std::ofstream ofs2{tensorInput2};
48  std::ofstream ofs3{tensorInput3};
49 
50 
51  for(auto entry : csvData.at(0))
52  {
53  ofs1 << entry << " ";
54  }
55  for(auto entry : csvData.at(1))
56  {
57  ofs2 << entry << " ";
58  }
59  for(auto entry : csvData.at(2))
60  {
61  ofs3 << entry << " ";
62  }
63 
64  ofs << "0, 0, " << tensorInput1.c_str() << std::endl;
65  ofs << "2, 0, " << tensorInput3.c_str() << std::endl;
66  ofs << "1, 0, " << tensorInput2.c_str() << std::endl;
67 
68  ofs.close();
69  ofs1.close();
70  ofs2.close();
71  ofs3.close();
72  }
73  catch (std::exception &e)
74  {
75  std::cerr << "Unable to write to file at location [" << p.c_str() << "] : " << e.what() << std::endl;
76  BOOST_TEST(false);
77  }
78 
79  m_CsvFile = p;
80  return p.string();
81  }
82 
83  void TearDown()
84  {
85  RemoveCsvFile();
86  }
87 
88  void RemoveCsvFile()
89  {
90  if (m_CsvFile)
91  {
92  try
93  {
94  fs::remove(m_CsvFile.value());
95  }
96  catch (std::exception &e)
97  {
98  std::cerr << "Unable to delete file [" << m_CsvFile.value() << "] : " << e.what() << std::endl;
99  BOOST_TEST(false);
100  }
101  }
102  }
103 
104  armnn::Optional<fs::path> m_CsvFile;
105 };
106 
107 
108 BOOST_AUTO_TEST_SUITE(QuantizationDataSetTests)
109 
110 BOOST_FIXTURE_TEST_CASE(CheckDataSet, CsvTestHelper)
111 {
112 
113  std::map<int, std::vector<float>> csvData;
114  csvData.insert(std::pair<int, std::vector<float>>(0, { 0.111111f, 0.222222f, 0.333333f }));
115  csvData.insert(std::pair<int, std::vector<float>>(1, { 0.444444f, 0.555555f, 0.666666f }));
116  csvData.insert(std::pair<int, std::vector<float>>(2, { 0.777777f, 0.888888f, 0.999999f }));
117 
118  std::string myCsvFile = CsvTestHelper::CreateTempCsvFile(csvData);
119  QuantizationDataSet dataSet(myCsvFile);
120  BOOST_TEST(!dataSet.IsEmpty());
121 
122  int csvRow = 0;
123  for(armnnQuantizer::QuantizationInput input : dataSet)
124  {
125  BOOST_TEST(input.GetPassId() == csvRow);
126 
127  BOOST_TEST(input.GetLayerBindingIds().size() == 1);
128  BOOST_TEST(input.GetLayerBindingIds()[0] == 0);
129  BOOST_TEST(input.GetDataForEntry(0).size() == 3);
130 
131  // Check that QuantizationInput data for binding ID 0 corresponds to float values
132  // used for populating the CSV file using by QuantizationDataSet
133  BOOST_TEST(input.GetDataForEntry(0).at(0) == csvData.at(csvRow).at(0));
134  BOOST_TEST(input.GetDataForEntry(0).at(1) == csvData.at(csvRow).at(1));
135  BOOST_TEST(input.GetDataForEntry(0).at(2) == csvData.at(csvRow).at(2));
136  ++csvRow;
137  }
138 }
139 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
QuantizationDataSet is a structure which is created after parsing a quantization CSV file.
QuantizationInput for specific pass ID, can list a corresponding raw data file for each LayerBindingId.
BOOST_AUTO_TEST_SUITE_END()
fs::path NamedTempFile(const char *fileName)
Construct a temporary file name.
Definition: Filesystem.cpp:23
BOOST_FIXTURE_TEST_CASE(CheckDataSet, CsvTestHelper)