aboutsummaryrefslogtreecommitdiff
path: root/src/armnnQuantizer
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2019-04-12 15:17:02 +0100
committerSadik Armagan <sadik.armagan@arm.com>2019-04-12 15:17:02 +0100
commit2b03d64da2a3a39d069a7a2366f14439afb1ad39 (patch)
tree7a6ed87ad3c9feeea779fa4e241d6f589cb0dc56 /src/armnnQuantizer
parent7b4886faccb52af9afe7fdeffcbae87e7fbc1484 (diff)
downloadarmnn-2b03d64da2a3a39d069a7a2366f14439afb1ad39.tar.gz
IVGCVSW-2955 Update the Quantizer Tool to take an additional parameter for the user to specify a CSV file
Change-Id: Id56e09f147cca5c1301ec1b6bac656cd50bfd583 Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Diffstat (limited to 'src/armnnQuantizer')
-rw-r--r--src/armnnQuantizer/ArmNNQuantizerMain.cpp33
-rw-r--r--src/armnnQuantizer/CommandLineProcessor.cpp23
-rw-r--r--src/armnnQuantizer/CommandLineProcessor.hpp5
3 files changed, 55 insertions, 6 deletions
diff --git a/src/armnnQuantizer/ArmNNQuantizerMain.cpp b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
index acfbe6241f..d7fc932327 100644
--- a/src/armnnQuantizer/ArmNNQuantizerMain.cpp
+++ b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
@@ -8,8 +8,9 @@
#include <armnn/INetworkQuantizer.hpp>
#include <armnnSerializer/ISerializer.hpp>
-#include <iostream>
+#include <algorithm>
#include <fstream>
+#include <iostream>
int main(int argc, char* argv[])
{
@@ -32,8 +33,36 @@ int main(int argc, char* argv[])
}
inputFileStream.close();
armnn::INetworkPtr network = parser->CreateNetworkFromBinary(binaryContent);
- armnn::INetworkPtr quantizedNetwork = armnn::INetworkQuantizer::Create(network.get())->ExportNetwork();
+ armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get());
+
+ std::string csvFileName = cmdline.GetCsvFileName();
+ if (csvFileName != "")
+ {
+ // Call the Quantizer::Refine() function which will update the min/max ranges for the quantize constants
+ std::ifstream csvFileStream(csvFileName);
+ std::string line;
+ std::string csvDirectory = cmdline.GetCsvFileDirectory();
+ while(getline(csvFileStream, line))
+ {
+ std::istringstream s(line);
+ std::vector<std::string> row;
+ std::string entry;
+ while(getline(s, entry, ','))
+ {
+ entry.erase(std::remove(entry.begin(), entry.end(), ' '), entry.end());
+ entry.erase(std::remove(entry.begin(), entry.end(), '"'), entry.end());
+ row.push_back(entry);
+ }
+ std::string rawFileName = cmdline.GetCsvFileDirectory() + "/" + row[2];
+ // passId: row[0]
+ // bindingId: row[1]
+ // rawFileName: file contains the RAW input tensor data
+ // LATER: Quantizer::Refine() function will be called with those arguments when it is implemented
+ }
+ csvFileStream.close();
+ }
+ armnn::INetworkPtr quantizedNetwork = quantizer->ExportNetwork();
armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
serializer->Serialize(*quantizedNetwork);
diff --git a/src/armnnQuantizer/CommandLineProcessor.cpp b/src/armnnQuantizer/CommandLineProcessor.cpp
index 1a10d38cdf..a7baa5cac4 100644
--- a/src/armnnQuantizer/CommandLineProcessor.cpp
+++ b/src/armnnQuantizer/CommandLineProcessor.cpp
@@ -42,17 +42,17 @@ bool ValidateOutputDirectory(std::string& dir)
return true;
}
-bool ValidateInputFile(const std::string& inputFileName)
+bool ValidateProvidedFile(const std::string& inputFileName)
{
if (!boost::filesystem::exists(inputFileName))
{
- std::cerr << "Input file [" << inputFileName << "] does not exist" << std::endl;
+ std::cerr << "Provided file [" << inputFileName << "] does not exist" << std::endl;
return false;
}
if (boost::filesystem::is_directory(inputFileName))
{
- std::cerr << "Given input file [" << inputFileName << "] is a directory" << std::endl;
+ std::cerr << "Given file [" << inputFileName << "] is a directory" << std::endl;
return false;
}
@@ -70,6 +70,8 @@ bool CommandLineProcessor::ProcessCommandLine(int argc, char* argv[])
("help,h", "Display help messages")
("infile,f", po::value<std::string>(&m_InputFileName)->required(),
"Input file containing float 32 ArmNN Input Graph")
+ ("csvfile,c", po::value<std::string>(&m_CsvFileName)->default_value(""),
+ "CSV file containing paths for RAW input tensors")
("outdir,d", po::value<std::string>(&m_OutputDirectory)->required(),
"Directory that output file will be written to")
("outfile,o", po::value<std::string>(&m_OutputFileName)->required(), "Output file name");
@@ -101,11 +103,24 @@ bool CommandLineProcessor::ProcessCommandLine(int argc, char* argv[])
return false;
}
- if (!armnnQuantizer::ValidateInputFile(m_InputFileName))
+ if (!armnnQuantizer::ValidateProvidedFile(m_InputFileName))
{
return false;
}
+ if (m_CsvFileName != "")
+ {
+ if (!armnnQuantizer::ValidateProvidedFile(m_CsvFileName))
+ {
+ return false;
+ }
+ else
+ {
+ boost::filesystem::path csvFilePath(m_CsvFileName);
+ m_CsvFileDirectory = csvFilePath.parent_path().c_str();
+ }
+ }
+
if (!armnnQuantizer::ValidateOutputDirectory(m_OutputDirectory))
{
return false;
diff --git a/src/armnnQuantizer/CommandLineProcessor.hpp b/src/armnnQuantizer/CommandLineProcessor.hpp
index f55e7a213f..852fcd4070 100644
--- a/src/armnnQuantizer/CommandLineProcessor.hpp
+++ b/src/armnnQuantizer/CommandLineProcessor.hpp
@@ -12,6 +12,7 @@ namespace armnnQuantizer
// parses the command line to extract
// * the input file -f containing the serialized fp32 ArmNN input graph (must exist...and be a input graph file)
+// * the csv file -c <optional> detailing the paths for RAW input tensors to use for refinement
// * the directory -d to place the output file into (must already exist and be writable)
// * the name of the file -o the quantized ArmNN input graph will be written to (must not already exist)
// * LATER: the min and max overrides to be applied to the inputs
@@ -25,10 +26,14 @@ public:
bool ProcessCommandLine(int argc, char* argv[]);
std::string GetInputFileName() {return m_InputFileName;}
+ std::string GetCsvFileName() {return m_CsvFileName;}
+ std::string GetCsvFileDirectory() {return m_CsvFileDirectory;}
std::string GetOutputDirectoryName() {return m_OutputDirectory;}
std::string GetOutputFileName() {return m_OutputFileName;}
private:
std::string m_InputFileName;
+ std::string m_CsvFileName;
+ std::string m_CsvFileDirectory;
std::string m_OutputDirectory;
std::string m_OutputFileName;
};