diff options
Diffstat (limited to 'src/armnnQuantizer/ArmNNQuantizerMain.cpp')
-rw-r--r-- | src/armnnQuantizer/ArmNNQuantizerMain.cpp | 43 |
1 files changed, 23 insertions, 20 deletions
diff --git a/src/armnnQuantizer/ArmNNQuantizerMain.cpp b/src/armnnQuantizer/ArmNNQuantizerMain.cpp index 9ac8966753..103597a72d 100644 --- a/src/armnnQuantizer/ArmNNQuantizerMain.cpp +++ b/src/armnnQuantizer/ArmNNQuantizerMain.cpp @@ -7,6 +7,8 @@ #include <armnnDeserializer/IDeserializer.hpp> #include <armnn/INetworkQuantizer.hpp> #include <armnnSerializer/ISerializer.hpp> +#include "QuantizationDataSet.hpp" +#include "QuantizationInput.hpp" #include <algorithm> #include <fstream> @@ -41,31 +43,32 @@ int main(int argc, char* argv[]) armnn::INetworkPtr network = parser->CreateNetworkFromBinary(binaryContent); armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get(), quantizerOptions); - std::string csvFileName = cmdline.GetCsvFileName(); - if (csvFileName != "") + if (cmdline.HasQuantizationData()) { - // Call the Quantizer::Refine() function which will update the min/max ranges for the quantize constants - std::ifstream csvFileStream(csvFileName); - std::string line; - std::string csvDirectory = cmdline.GetCsvFileDirectory(); - while(getline(csvFileStream, line)) + armnnQuantizer::QuantizationDataSet dataSet = cmdline.GetQuantizationDataSet(); + if (!dataSet.IsEmpty()) { - std::istringstream s(line); - std::vector<std::string> row; - std::string entry; - while(getline(s, entry, ',')) + // Get the Input Tensor Infos + armnnQuantizer::InputLayerVisitor inputLayerVisitor; + network->Accept(inputLayerVisitor); + + for(armnnQuantizer::QuantizationInput quantizationInput : dataSet) { - entry.erase(std::remove(entry.begin(), entry.end(), ' '), entry.end()); - entry.erase(std::remove(entry.begin(), entry.end(), '"'), entry.end()); - row.push_back(entry); + armnn::InputTensors inputTensors; + std::vector<std::vector<float>> inputData(quantizationInput.GetNumberOfInputs()); + std::vector<armnn::LayerBindingId> layerBindingIds = quantizationInput.GetLayerBindingIds(); + unsigned int count = 0; + for (armnn::LayerBindingId layerBindingId : quantizationInput.GetLayerBindingIds()) + { + armnn::TensorInfo tensorInfo = inputLayerVisitor.GetTensorInfo(layerBindingId); + inputData[count] = quantizationInput.GetDataForEntry(layerBindingId); + armnn::ConstTensor inputTensor(tensorInfo, inputData[count].data()); + inputTensors.push_back(std::make_pair(layerBindingId, inputTensor)); + count++; + } + quantizer->Refine(inputTensors); } - std::string rawFileName = cmdline.GetCsvFileDirectory() + "/" + row[2]; - // passId: row[0] - // bindingId: row[1] - // rawFileName: file contains the RAW input tensor data - // LATER: Quantizer::Refine() function will be called with those arguments when it is implemented } - csvFileStream.close(); } armnn::INetworkPtr quantizedNetwork = quantizer->ExportNetwork(); |