aboutsummaryrefslogtreecommitdiff
path: root/src/armnnQuantizer/ArmNNQuantizerMain.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnQuantizer/ArmNNQuantizerMain.cpp')
-rw-r--r--src/armnnQuantizer/ArmNNQuantizerMain.cpp43
1 files changed, 23 insertions, 20 deletions
diff --git a/src/armnnQuantizer/ArmNNQuantizerMain.cpp b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
index 9ac8966753..103597a72d 100644
--- a/src/armnnQuantizer/ArmNNQuantizerMain.cpp
+++ b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
@@ -7,6 +7,8 @@
#include <armnnDeserializer/IDeserializer.hpp>
#include <armnn/INetworkQuantizer.hpp>
#include <armnnSerializer/ISerializer.hpp>
+#include "QuantizationDataSet.hpp"
+#include "QuantizationInput.hpp"
#include <algorithm>
#include <fstream>
@@ -41,31 +43,32 @@ int main(int argc, char* argv[])
armnn::INetworkPtr network = parser->CreateNetworkFromBinary(binaryContent);
armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get(), quantizerOptions);
- std::string csvFileName = cmdline.GetCsvFileName();
- if (csvFileName != "")
+ if (cmdline.HasQuantizationData())
{
- // Call the Quantizer::Refine() function which will update the min/max ranges for the quantize constants
- std::ifstream csvFileStream(csvFileName);
- std::string line;
- std::string csvDirectory = cmdline.GetCsvFileDirectory();
- while(getline(csvFileStream, line))
+ armnnQuantizer::QuantizationDataSet dataSet = cmdline.GetQuantizationDataSet();
+ if (!dataSet.IsEmpty())
{
- std::istringstream s(line);
- std::vector<std::string> row;
- std::string entry;
- while(getline(s, entry, ','))
+ // Get the Input Tensor Infos
+ armnnQuantizer::InputLayerVisitor inputLayerVisitor;
+ network->Accept(inputLayerVisitor);
+
+ for(armnnQuantizer::QuantizationInput quantizationInput : dataSet)
{
- entry.erase(std::remove(entry.begin(), entry.end(), ' '), entry.end());
- entry.erase(std::remove(entry.begin(), entry.end(), '"'), entry.end());
- row.push_back(entry);
+ armnn::InputTensors inputTensors;
+ std::vector<std::vector<float>> inputData(quantizationInput.GetNumberOfInputs());
+ std::vector<armnn::LayerBindingId> layerBindingIds = quantizationInput.GetLayerBindingIds();
+ unsigned int count = 0;
+ for (armnn::LayerBindingId layerBindingId : quantizationInput.GetLayerBindingIds())
+ {
+ armnn::TensorInfo tensorInfo = inputLayerVisitor.GetTensorInfo(layerBindingId);
+ inputData[count] = quantizationInput.GetDataForEntry(layerBindingId);
+ armnn::ConstTensor inputTensor(tensorInfo, inputData[count].data());
+ inputTensors.push_back(std::make_pair(layerBindingId, inputTensor));
+ count++;
+ }
+ quantizer->Refine(inputTensors);
}
- std::string rawFileName = cmdline.GetCsvFileDirectory() + "/" + row[2];
- // passId: row[0]
- // bindingId: row[1]
- // rawFileName: file contains the RAW input tensor data
- // LATER: Quantizer::Refine() function will be called with those arguments when it is implemented
}
- csvFileStream.close();
}
armnn::INetworkPtr quantizedNetwork = quantizer->ExportNetwork();