diff options
Diffstat (limited to 'src/armnnQuantizer')
-rw-r--r-- | src/armnnQuantizer/ArmNNQuantizerMain.cpp | 33 | ||||
-rw-r--r-- | src/armnnQuantizer/CommandLineProcessor.cpp | 23 | ||||
-rw-r--r-- | src/armnnQuantizer/CommandLineProcessor.hpp | 5 |
3 files changed, 55 insertions, 6 deletions
diff --git a/src/armnnQuantizer/ArmNNQuantizerMain.cpp b/src/armnnQuantizer/ArmNNQuantizerMain.cpp index acfbe6241f..d7fc932327 100644 --- a/src/armnnQuantizer/ArmNNQuantizerMain.cpp +++ b/src/armnnQuantizer/ArmNNQuantizerMain.cpp @@ -8,8 +8,9 @@ #include <armnn/INetworkQuantizer.hpp> #include <armnnSerializer/ISerializer.hpp> -#include <iostream> +#include <algorithm> #include <fstream> +#include <iostream> int main(int argc, char* argv[]) { @@ -32,8 +33,36 @@ int main(int argc, char* argv[]) } inputFileStream.close(); armnn::INetworkPtr network = parser->CreateNetworkFromBinary(binaryContent); - armnn::INetworkPtr quantizedNetwork = armnn::INetworkQuantizer::Create(network.get())->ExportNetwork(); + armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get()); + + std::string csvFileName = cmdline.GetCsvFileName(); + if (csvFileName != "") + { + // Call the Quantizer::Refine() function which will update the min/max ranges for the quantize constants + std::ifstream csvFileStream(csvFileName); + std::string line; + std::string csvDirectory = cmdline.GetCsvFileDirectory(); + while(getline(csvFileStream, line)) + { + std::istringstream s(line); + std::vector<std::string> row; + std::string entry; + while(getline(s, entry, ',')) + { + entry.erase(std::remove(entry.begin(), entry.end(), ' '), entry.end()); + entry.erase(std::remove(entry.begin(), entry.end(), '"'), entry.end()); + row.push_back(entry); + } + std::string rawFileName = cmdline.GetCsvFileDirectory() + "/" + row[2]; + // passId: row[0] + // bindingId: row[1] + // rawFileName: file contains the RAW input tensor data + // LATER: Quantizer::Refine() function will be called with those arguments when it is implemented + } + csvFileStream.close(); + } + armnn::INetworkPtr quantizedNetwork = quantizer->ExportNetwork(); armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create(); serializer->Serialize(*quantizedNetwork); diff --git a/src/armnnQuantizer/CommandLineProcessor.cpp b/src/armnnQuantizer/CommandLineProcessor.cpp index 1a10d38cdf..a7baa5cac4 100644 --- a/src/armnnQuantizer/CommandLineProcessor.cpp +++ b/src/armnnQuantizer/CommandLineProcessor.cpp @@ -42,17 +42,17 @@ bool ValidateOutputDirectory(std::string& dir) return true; } -bool ValidateInputFile(const std::string& inputFileName) +bool ValidateProvidedFile(const std::string& inputFileName) { if (!boost::filesystem::exists(inputFileName)) { - std::cerr << "Input file [" << inputFileName << "] does not exist" << std::endl; + std::cerr << "Provided file [" << inputFileName << "] does not exist" << std::endl; return false; } if (boost::filesystem::is_directory(inputFileName)) { - std::cerr << "Given input file [" << inputFileName << "] is a directory" << std::endl; + std::cerr << "Given file [" << inputFileName << "] is a directory" << std::endl; return false; } @@ -70,6 +70,8 @@ bool CommandLineProcessor::ProcessCommandLine(int argc, char* argv[]) ("help,h", "Display help messages") ("infile,f", po::value<std::string>(&m_InputFileName)->required(), "Input file containing float 32 ArmNN Input Graph") + ("csvfile,c", po::value<std::string>(&m_CsvFileName)->default_value(""), + "CSV file containing paths for RAW input tensors") ("outdir,d", po::value<std::string>(&m_OutputDirectory)->required(), "Directory that output file will be written to") ("outfile,o", po::value<std::string>(&m_OutputFileName)->required(), "Output file name"); @@ -101,11 +103,24 @@ bool CommandLineProcessor::ProcessCommandLine(int argc, char* argv[]) return false; } - if (!armnnQuantizer::ValidateInputFile(m_InputFileName)) + if (!armnnQuantizer::ValidateProvidedFile(m_InputFileName)) { return false; } + if (m_CsvFileName != "") + { + if (!armnnQuantizer::ValidateProvidedFile(m_CsvFileName)) + { + return false; + } + else + { + boost::filesystem::path csvFilePath(m_CsvFileName); + m_CsvFileDirectory = csvFilePath.parent_path().c_str(); + } + } + if (!armnnQuantizer::ValidateOutputDirectory(m_OutputDirectory)) { return false; diff --git a/src/armnnQuantizer/CommandLineProcessor.hpp b/src/armnnQuantizer/CommandLineProcessor.hpp index f55e7a213f..852fcd4070 100644 --- a/src/armnnQuantizer/CommandLineProcessor.hpp +++ b/src/armnnQuantizer/CommandLineProcessor.hpp @@ -12,6 +12,7 @@ namespace armnnQuantizer // parses the command line to extract // * the input file -f containing the serialized fp32 ArmNN input graph (must exist...and be a input graph file) +// * the csv file -c <optional> detailing the paths for RAW input tensors to use for refinement // * the directory -d to place the output file into (must already exist and be writable) // * the name of the file -o the quantized ArmNN input graph will be written to (must not already exist) // * LATER: the min and max overrides to be applied to the inputs @@ -25,10 +26,14 @@ public: bool ProcessCommandLine(int argc, char* argv[]); std::string GetInputFileName() {return m_InputFileName;} + std::string GetCsvFileName() {return m_CsvFileName;} + std::string GetCsvFileDirectory() {return m_CsvFileDirectory;} std::string GetOutputDirectoryName() {return m_OutputDirectory;} std::string GetOutputFileName() {return m_OutputFileName;} private: std::string m_InputFileName; + std::string m_CsvFileName; + std::string m_CsvFileDirectory; std::string m_OutputDirectory; std::string m_OutputFileName; }; |