aboutsummaryrefslogtreecommitdiff
path: root/src/armnnQuantizer/QuantizationDataSet.cpp
diff options
context:
space:
mode:
authorNina Drozd <nina.drozd@arm.com>2019-04-25 15:45:20 +0100
committerNina Drozd <nina.drozd@arm.com>2019-05-03 14:43:50 +0100
commit59e15b00ea51fee4baeea750dc11ab1952dfab1d (patch)
tree97e6c9230bf153d404ad3c3e0e285acdb0b6232d /src/armnnQuantizer/QuantizationDataSet.cpp
parent8b194fbe79d44cba566ad8b508d1c8902987ae3c (diff)
downloadarmnn-59e15b00ea51fee4baeea750dc11ab1952dfab1d.tar.gz
IVGCVSW-2834 Add dynamic quantization via datasets
* Add QuantizationDataSet class for quantization data parsed from CSV file * Add QuantizationInput for retrieving quantization data for each layer ID * Add unit tests for command line processor and QuantizationDataSet Change-Id: Iaf0a747b5f25a59a766ac04f7158e8cb7909d179 Signed-off-by: Nina Drozd <nina.drozd@arm.com>
Diffstat (limited to 'src/armnnQuantizer/QuantizationDataSet.cpp')
-rw-r--r--src/armnnQuantizer/QuantizationDataSet.cpp165
1 files changed, 165 insertions, 0 deletions
diff --git a/src/armnnQuantizer/QuantizationDataSet.cpp b/src/armnnQuantizer/QuantizationDataSet.cpp
new file mode 100644
index 0000000000..d225883854
--- /dev/null
+++ b/src/armnnQuantizer/QuantizationDataSet.cpp
@@ -0,0 +1,165 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "QuantizationDataSet.hpp"
+#include "CsvReader.hpp"
+
+#define BOOST_FILESYSTEM_NO_DEPRECATED
+
+#include <boost/filesystem/operations.hpp>
+#include <boost/filesystem/path.hpp>
+
+namespace armnnQuantizer
+{
+
+QuantizationDataSet::QuantizationDataSet()
+{
+}
+
+QuantizationDataSet::QuantizationDataSet(const std::string csvFilePath):
+ m_QuantizationInputs(),
+ m_CsvFilePath(csvFilePath)
+{
+ ParseCsvFile();
+}
+
+void AddInputData(unsigned int passId,
+ armnn::LayerBindingId bindingId,
+ const std::string& inputFilePath,
+ std::map<unsigned int, QuantizationInput>& passIdToQuantizationInput)
+{
+ auto iterator = passIdToQuantizationInput.find(passId);
+ if (iterator == passIdToQuantizationInput.end())
+ {
+ QuantizationInput input(passId, bindingId, inputFilePath);
+ passIdToQuantizationInput.emplace(passId, input);
+ }
+ else
+ {
+ auto existingQuantizationInput = iterator->second;
+ existingQuantizationInput.AddEntry(bindingId, inputFilePath);
+ }
+}
+
+QuantizationDataSet::~QuantizationDataSet()
+{
+}
+
+void InputLayerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer,
+ armnn::LayerBindingId id,
+ const char* name)
+{
+ m_TensorInfos.emplace(id, layer->GetInputSlot(0).GetConnection()->GetTensorInfo());
+}
+
+armnn::TensorInfo InputLayerVisitor::GetTensorInfo(armnn::LayerBindingId layerBindingId)
+{
+ auto iterator = m_TensorInfos.find(layerBindingId);
+ if (iterator != m_TensorInfos.end())
+ {
+ return m_TensorInfos.at(layerBindingId);
+ }
+ else
+ {
+ throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId));
+ }
+}
+
+
+unsigned int GetPassIdFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
+{
+ unsigned int passId;
+ try
+ {
+ passId = static_cast<unsigned int>(std::stoi(csvRows[rowIndex].values[0]));
+ }
+ catch (std::invalid_argument)
+ {
+ throw armnn::ParseException("Pass ID [" + csvRows[rowIndex].values[0] + "]" +
+ " is not correct format on CSV row " + std::to_string(rowIndex));
+ }
+ return passId;
+}
+
+armnn::LayerBindingId GetBindingIdFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
+{
+ armnn::LayerBindingId bindingId;
+ try
+ {
+ bindingId = std::stoi(csvRows[rowIndex].values[1]);
+ }
+ catch (std::invalid_argument)
+ {
+ throw armnn::ParseException("Binding ID [" + csvRows[rowIndex].values[0] + "]" +
+ " is not correct format on CSV row " + std::to_string(rowIndex));
+ }
+ return bindingId;
+}
+
+std::string GetFileNameFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
+{
+ std::string fileName = csvRows[rowIndex].values[2];
+
+ if (!boost::filesystem::exists(fileName))
+ {
+ throw armnn::ParseException("File [ " + fileName + "] provided on CSV row " + std::to_string(rowIndex) +
+ " does not exist.");
+ }
+
+ if (fileName.empty())
+ {
+ throw armnn::ParseException("Filename cannot be empty on CSV row " + std::to_string(rowIndex));
+ }
+ return fileName;
+}
+
+
+void QuantizationDataSet::ParseCsvFile()
+{
+ std::map<unsigned int, QuantizationInput> passIdToQuantizationInput;
+ armnnUtils::CsvReader reader;
+
+ if (m_CsvFilePath == "")
+ {
+ throw armnn::Exception("CSV file not specified.");
+ }
+
+ // Parse CSV file and extract data
+ std::vector<armnnUtils::CsvRow> csvRows = reader.ParseFile(m_CsvFilePath);
+ if (csvRows.empty())
+ {
+ throw armnn::Exception("CSV file [" + m_CsvFilePath + "] is empty.");
+ }
+
+ for (unsigned int i = 0; i < csvRows.size(); ++i)
+ {
+ if (csvRows[i].values.size() != 3)
+ {
+ throw armnn::Exception("CSV file [" + m_CsvFilePath + "] does not have correct number of entries " +
+ "on line " + std::to_string(i) + ". Expected 3 entries " +
+ "but was " + std::to_string(csvRows[i].values.size()));
+ }
+
+ unsigned int passId = GetPassIdFromCsvRow(csvRows, i);
+ armnn::LayerBindingId bindingId = GetBindingIdFromCsvRow(csvRows, i);
+ std::string rawFileName = GetFileNameFromCsvRow(csvRows, i);
+
+ AddInputData(passId, bindingId, rawFileName, passIdToQuantizationInput);
+ }
+
+ if (passIdToQuantizationInput.empty())
+ {
+ throw armnn::Exception("Could not parse CSV file.");
+ }
+
+ // Once all entries in CSV file are parsed successfully and QuantizationInput map is populated, populate
+ // QuantizationInputs iterator for easier access and clear the map
+ for (auto itr = passIdToQuantizationInput.begin(); itr != passIdToQuantizationInput.end(); ++itr)
+ {
+ m_QuantizationInputs.emplace_back(itr->second);
+ }
+}
+
+}