aboutsummaryrefslogtreecommitdiff
path: root/src/armnnQuantizer/ArmNNQuantizerMain.cpp
diff options
context:
space:
mode:
authorNina Drozd <nina.drozd@arm.com>2019-04-25 15:45:20 +0100
committerNina Drozd <nina.drozd@arm.com>2019-05-03 14:43:50 +0100
commit59e15b00ea51fee4baeea750dc11ab1952dfab1d (patch)
tree97e6c9230bf153d404ad3c3e0e285acdb0b6232d /src/armnnQuantizer/ArmNNQuantizerMain.cpp
parent8b194fbe79d44cba566ad8b508d1c8902987ae3c (diff)
downloadarmnn-59e15b00ea51fee4baeea750dc11ab1952dfab1d.tar.gz
IVGCVSW-2834 Add dynamic quantization via datasets
* Add QuantizationDataSet class for quantization data parsed from CSV file * Add QuantizationInput for retrieving quantization data for each layer ID * Add unit tests for command line processor and QuantizationDataSet Change-Id: Iaf0a747b5f25a59a766ac04f7158e8cb7909d179 Signed-off-by: Nina Drozd <nina.drozd@arm.com>
Diffstat (limited to 'src/armnnQuantizer/ArmNNQuantizerMain.cpp')
-rw-r--r--src/armnnQuantizer/ArmNNQuantizerMain.cpp43
1 files changed, 23 insertions, 20 deletions
diff --git a/src/armnnQuantizer/ArmNNQuantizerMain.cpp b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
index 9ac8966753..103597a72d 100644
--- a/src/armnnQuantizer/ArmNNQuantizerMain.cpp
+++ b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
@@ -7,6 +7,8 @@
#include <armnnDeserializer/IDeserializer.hpp>
#include <armnn/INetworkQuantizer.hpp>
#include <armnnSerializer/ISerializer.hpp>
+#include "QuantizationDataSet.hpp"
+#include "QuantizationInput.hpp"
#include <algorithm>
#include <fstream>
@@ -41,31 +43,32 @@ int main(int argc, char* argv[])
armnn::INetworkPtr network = parser->CreateNetworkFromBinary(binaryContent);
armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get(), quantizerOptions);
- std::string csvFileName = cmdline.GetCsvFileName();
- if (csvFileName != "")
+ if (cmdline.HasQuantizationData())
{
- // Call the Quantizer::Refine() function which will update the min/max ranges for the quantize constants
- std::ifstream csvFileStream(csvFileName);
- std::string line;
- std::string csvDirectory = cmdline.GetCsvFileDirectory();
- while(getline(csvFileStream, line))
+ armnnQuantizer::QuantizationDataSet dataSet = cmdline.GetQuantizationDataSet();
+ if (!dataSet.IsEmpty())
{
- std::istringstream s(line);
- std::vector<std::string> row;
- std::string entry;
- while(getline(s, entry, ','))
+ // Get the Input Tensor Infos
+ armnnQuantizer::InputLayerVisitor inputLayerVisitor;
+ network->Accept(inputLayerVisitor);
+
+ for(armnnQuantizer::QuantizationInput quantizationInput : dataSet)
{
- entry.erase(std::remove(entry.begin(), entry.end(), ' '), entry.end());
- entry.erase(std::remove(entry.begin(), entry.end(), '"'), entry.end());
- row.push_back(entry);
+ armnn::InputTensors inputTensors;
+ std::vector<std::vector<float>> inputData(quantizationInput.GetNumberOfInputs());
+ std::vector<armnn::LayerBindingId> layerBindingIds = quantizationInput.GetLayerBindingIds();
+ unsigned int count = 0;
+ for (armnn::LayerBindingId layerBindingId : quantizationInput.GetLayerBindingIds())
+ {
+ armnn::TensorInfo tensorInfo = inputLayerVisitor.GetTensorInfo(layerBindingId);
+ inputData[count] = quantizationInput.GetDataForEntry(layerBindingId);
+ armnn::ConstTensor inputTensor(tensorInfo, inputData[count].data());
+ inputTensors.push_back(std::make_pair(layerBindingId, inputTensor));
+ count++;
+ }
+ quantizer->Refine(inputTensors);
}
- std::string rawFileName = cmdline.GetCsvFileDirectory() + "/" + row[2];
- // passId: row[0]
- // bindingId: row[1]
- // rawFileName: file contains the RAW input tensor data
- // LATER: Quantizer::Refine() function will be called with those arguments when it is implemented
}
- csvFileStream.close();
}
armnn::INetworkPtr quantizedNetwork = quantizer->ExportNetwork();