ArmNN
 20.11
QuantizationDataSet.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 
8 #include <fmt/format.h>
9 
12 #include <Filesystem.hpp>
13 
14 namespace armnnQuantizer
15 {
16 
18 {
19 }
20 
21 QuantizationDataSet::QuantizationDataSet(const std::string csvFilePath):
22  m_QuantizationInputs(),
23  m_CsvFilePath(csvFilePath)
24 {
25  ParseCsvFile();
26 }
27 
28 void AddInputData(unsigned int passId,
29  armnn::LayerBindingId bindingId,
30  const std::string& inputFilePath,
31  std::map<unsigned int, QuantizationInput>& passIdToQuantizationInput)
32 {
33  auto iterator = passIdToQuantizationInput.find(passId);
34  if (iterator == passIdToQuantizationInput.end())
35  {
36  QuantizationInput input(passId, bindingId, inputFilePath);
37  passIdToQuantizationInput.emplace(passId, input);
38  }
39  else
40  {
41  auto existingQuantizationInput = iterator->second;
42  existingQuantizationInput.AddEntry(bindingId, inputFilePath);
43  }
44 }
45 
47 {
48 }
49 
52  const char* name)
53 {
54  armnn::IgnoreUnused(name);
55  m_TensorInfos.emplace(id, layer->GetOutputSlot(0).GetTensorInfo());
56 }
57 
59 {
60  auto iterator = m_TensorInfos.find(layerBindingId);
61  if (iterator != m_TensorInfos.end())
62  {
63  return m_TensorInfos.at(layerBindingId);
64  }
65  else
66  {
67  throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId));
68  }
69 }
70 
71 
72 unsigned int GetPassIdFromCsvRow(std::vector<std::string> tokens, unsigned int lineIndex)
73 {
74  unsigned int passId;
75  try
76  {
77  passId = static_cast<unsigned int>(std::stoi(tokens[0]));
78  }
79  catch (const std::invalid_argument&)
80  {
81  throw armnn::ParseException(fmt::format("Pass ID [{}] is not correct format on CSV row {}",
82  tokens[0], lineIndex));
83  }
84  return passId;
85 }
86 
87 armnn::LayerBindingId GetBindingIdFromCsvRow(std::vector<std::string> tokens, unsigned int lineIndex)
88 {
89  armnn::LayerBindingId bindingId;
90  try
91  {
92  bindingId = std::stoi(tokens[1]);
93  }
94  catch (const std::invalid_argument&)
95  {
96  throw armnn::ParseException(fmt::format("Binding ID [{}] is not correct format on CSV row {}",
97  tokens[1], lineIndex));
98  }
99  return bindingId;
100 }
101 
102 std::string GetFileNameFromCsvRow(std::vector<std::string> tokens, unsigned int lineIndex)
103 {
104  std::string fileName = armnn::stringUtils::StringTrim(tokens[2]);
105 
106  if (!fs::exists(fileName))
107  {
108  throw armnn::ParseException(fmt::format("File [{}] provided on CSV row {} does not exist.",
109  fileName, lineIndex));
110  }
111 
112  if (fileName.empty())
113  {
114  throw armnn::ParseException(fmt::format("Filename cannot be empty on CSV row {} ", lineIndex));
115  }
116  return fileName;
117 }
118 
119 
120 void QuantizationDataSet::ParseCsvFile()
121 {
122  std::map<unsigned int, QuantizationInput> passIdToQuantizationInput;
123 
124  if (m_CsvFilePath == "")
125  {
126  throw armnn::Exception("CSV file not specified.");
127  }
128 
129  std::ifstream inf (m_CsvFilePath.c_str());
130  std::string line;
131  std::vector<std::string> tokens;
132  unsigned int lineIndex = 0;
133 
134  if (!inf)
135  {
136  throw armnn::Exception(fmt::format("CSV file {} not found.", m_CsvFilePath));
137  }
138 
139  while (getline(inf, line))
140  {
141  tokens = armnn::stringUtils::StringTokenizer(line, ",");
142 
143  if (tokens.size() != 3)
144  {
145  throw armnn::Exception(fmt::format("CSV file [{}] does not have correct number of entries" \
146  "on line {}. Expected 3 entries but was {}.",
147  m_CsvFilePath, lineIndex, tokens.size()));
148 
149  }
150 
151  unsigned int passId = GetPassIdFromCsvRow(tokens, lineIndex);
152  armnn::LayerBindingId bindingId = GetBindingIdFromCsvRow(tokens, lineIndex);
153  std::string rawFileName = GetFileNameFromCsvRow(tokens, lineIndex);
154 
155  AddInputData(passId, bindingId, rawFileName, passIdToQuantizationInput);
156 
157  ++lineIndex;
158  }
159 
160  if (passIdToQuantizationInput.empty())
161  {
162  throw armnn::Exception("Could not parse CSV file.");
163  }
164 
165  // Once all entries in CSV file are parsed successfully and QuantizationInput map is populated, populate
166  // QuantizationInputs iterator for easier access and clear the map
167  for (auto itr = passIdToQuantizationInput.begin(); itr != passIdToQuantizationInput.end(); ++itr)
168  {
169  m_QuantizationInputs.emplace_back(itr->second);
170  }
171 }
172 
173 }
std::vector< std::string > StringTokenizer(const std::string &str, const char *delimiters, bool tokenCompression=true)
Function to take a string and a list of delimiters and split the string into tokens based on those delimiters.
Definition: StringUtils.hpp:20
void AddInputData(unsigned int passId, armnn::LayerBindingId bindingId, const std::string &inputFilePath, std::map< unsigned int, QuantizationInput > &passIdToQuantizationInput)
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
Definition: INetwork.hpp:61
std::string GetFileNameFromCsvRow(std::vector< std::string > tokens, unsigned int lineIndex)
void VisitInputLayer(const armnn::IConnectableLayer *layer, armnn::LayerBindingId id, const char *name)
Function that an InputLayer should call back to when its Accept(ILayerVisitor&) function is invoked.
QuantizationInputs::iterator iterator
armnn::LayerBindingId GetBindingIdFromCsvRow(std::vector< std::string > tokens, unsigned int lineIndex)
QuantizationInput for a specific pass ID; it can list a corresponding raw data file for each LayerBindingId.
void IgnoreUnused(Ts &&...)
int LayerBindingId
Type of identifiers for bindable layers (inputs, outputs).
Definition: Types.hpp:202
armnn::TensorInfo GetTensorInfo(armnn::LayerBindingId)
unsigned int GetPassIdFromCsvRow(std::vector< std::string > tokens, unsigned int lineIndex)
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
virtual const TensorInfo & GetTensorInfo() const =0
virtual const IOutputSlot & GetOutputSlot(unsigned int index) const =0
Get the const output slot handle by slot index.
std::string & StringTrim(std::string &str, const std::string &chars="\t\n\v\f\r ")
Trim from both the start and the end of a string.
Definition: StringUtils.hpp:77