// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #include "ImageTensorGenerator.hpp" #include "../InferenceTestImage.hpp" #include #include #include #include #include #include #include #include namespace { // parses the command line to extract // * the input image file -i the input image file path (must exist) // * the layout -l the data layout output generated with (optional - default value is NHWC) // * the output file -o the output raw tensor file path (must not already exist) class CommandLineProcessor { public: bool ParseOptions(cxxopts::ParseResult& result) { // infile is mandatory if (result.count("infile")) { if (!ValidateInputFile(result["infile"].as())) { return false; } } else { std::cerr << "-i/--infile parameter is mandatory." << std::endl; return false; } // model-format is mandatory if (!result.count("model-format")) { std::cerr << "-f/--model-format parameter is mandatory." << std::endl; return false; } // outfile is mandatory if (result.count("outfile")) { if (!ValidateOutputFile(result["outfile"].as())) { return false; } } else { std::cerr << "-o/--outfile parameter is mandatory." << std::endl; return false; } if (result.count("layout")) { if(!ValidateLayout(result["layout"].as())) { return false; } } return true; } bool ValidateInputFile(const std::string& inputFileName) { if (inputFileName.empty()) { std::cerr << "No input file name specified" << std::endl; return false; } if (!fs::exists(inputFileName)) { std::cerr << "Input file [" << inputFileName << "] does not exist" << std::endl; return false; } if (fs::is_directory(inputFileName)) { std::cerr << "Input file [" << inputFileName << "] is a directory" << std::endl; return false; } return true; } bool ValidateLayout(const std::string& layout) { if (layout.empty()) { std::cerr << "No layout specified" << std::endl; return false; } std::vector supportedLayouts = { "NHWC", "NCHW" }; auto iterator = std::find(supportedLayouts.begin(), supportedLayouts.end(), layout); if (iterator == supportedLayouts.end()) { std::cerr << "Layout [" << layout << "] is not supported" << std::endl; return false; } return true; } bool ValidateOutputFile(const std::string& outputFileName) { if (outputFileName.empty()) { std::cerr << "No output file name specified" << std::endl; return false; } if (fs::exists(outputFileName)) { std::cerr << "Output file [" << outputFileName << "] already exists" << std::endl; return false; } if (fs::is_directory(outputFileName)) { std::cerr << "Output file [" << outputFileName << "] is a directory" << std::endl; return false; } fs::path outputPath(outputFileName); if (!fs::exists(outputPath.parent_path())) { std::cerr << "Output directory [" << outputPath.parent_path().c_str() << "] does not exist" << std::endl; return false; } return true; } bool ProcessCommandLine(int argc, char* argv[]) { cxxopts::Options options("ImageTensorGenerator", "Program for pre-processing a .jpg image " "before generating a .raw tensor file from it."); try { options.add_options() ("h,help", "Display help messages") ("i,infile", "Input image file to generate tensor from", cxxopts::value(m_InputFileName)) ("f,model-format", "Format of the intended model file that uses the images." "Different formats have different image normalization styles." "If unset, defaults to tflite." "Accepted value (tflite)", cxxopts::value(m_ModelFormat)->default_value("tflite")) ("o,outfile", "Output raw tensor file path", cxxopts::value(m_OutputFileName)) ("z,output-type", "The data type of the output tensors." "If unset, defaults to \"float\" for all defined inputs. " "Accepted values (float, int, qasymms8 or qasymmu8)", cxxopts::value(m_OutputType)->default_value("float")) ("new-width", "Resize image to new width. Keep original width if unspecified", cxxopts::value(m_NewWidth)->default_value("0")) ("new-height", "Resize image to new height. Keep original height if unspecified", cxxopts::value(m_NewHeight)->default_value("0")) ("l,layout", "Output data layout, \"NHWC\" or \"NCHW\", default value NHWC", cxxopts::value(m_Layout)->default_value("NHWC")); } catch (const std::exception& e) { std::cerr << options.help() << std::endl; return false; } try { auto result = options.parse(argc, argv); if (result.count("help")) { std::cout << options.help() << std::endl; return false; } // Check for mandatory parameters and validate inputs if(!ParseOptions(result)){ return false; } } catch (const cxxopts::OptionException& e) { std::cerr << e.what() << std::endl << std::endl; return false; } return true; } std::string GetInputFileName() {return m_InputFileName;} armnn::DataLayout GetLayout() { if (m_Layout == "NHWC") { return armnn::DataLayout::NHWC; } else if (m_Layout == "NCHW") { return armnn::DataLayout::NCHW; } else { throw armnn::Exception("Unsupported data layout: " + m_Layout); } } std::string GetOutputFileName() {return m_OutputFileName;} unsigned int GetNewWidth() {return static_cast(std::stoi(m_NewWidth));} unsigned int GetNewHeight() {return static_cast(std::stoi(m_NewHeight));} SupportedFrontend GetModelFormat() { if (m_ModelFormat == "tflite") { return SupportedFrontend::TFLite; } else { throw armnn::Exception("Unsupported model format" + m_ModelFormat); } } armnn::DataType GetOutputType() { if (m_OutputType == "float") { return armnn::DataType::Float32; } else if (m_OutputType == "int") { return armnn::DataType::Signed32; } else if (m_OutputType == "qasymm8" || m_OutputType == "qasymmu8") { return armnn::DataType::QAsymmU8; } else if (m_OutputType == "qasymms8") { return armnn::DataType::QAsymmS8; } else { throw armnn::Exception("Unsupported input type" + m_OutputType); } } private: std::string m_InputFileName; std::string m_Layout; std::string m_OutputFileName; std::string m_NewWidth; std::string m_NewHeight; std::string m_ModelFormat; std::string m_OutputType; }; } // namespace anonymous int main(int argc, char* argv[]) { CommandLineProcessor cmdline; if (!cmdline.ProcessCommandLine(argc, argv)) { return -1; } const std::string imagePath(cmdline.GetInputFileName()); const std::string outputPath(cmdline.GetOutputFileName()); const SupportedFrontend& modelFormat(cmdline.GetModelFormat()); const armnn::DataType outputType(cmdline.GetOutputType()); const unsigned int newWidth = cmdline.GetNewWidth(); const unsigned int newHeight = cmdline.GetNewHeight(); const unsigned int batchSize = 1; const armnn::DataLayout outputLayout(cmdline.GetLayout()); std::vector imageDataContainers; const NormalizationParameters& normParams = GetNormalizationParameters(modelFormat, outputType); try { switch (outputType) { case armnn::DataType::Signed32: imageDataContainers.push_back(PrepareImageTensor( imagePath, newWidth, newHeight, normParams, batchSize, outputLayout)); break; case armnn::DataType::QAsymmU8: imageDataContainers.push_back(PrepareImageTensor( imagePath, newWidth, newHeight, normParams, batchSize, outputLayout)); break; case armnn::DataType::QAsymmS8: imageDataContainers.push_back(PrepareImageTensor( imagePath, newWidth, newHeight, normParams, batchSize, outputLayout)); break; case armnn::DataType::Float32: default: imageDataContainers.push_back(PrepareImageTensor( imagePath, newWidth, newHeight, normParams, batchSize, outputLayout)); break; } } catch (const InferenceTestImageException& e) { ARMNN_LOG(fatal) << "Failed to load image file " << imagePath << " with error: " << e.what(); return -1; } std::ofstream imageTensorFile; imageTensorFile.open(outputPath, std::ofstream::out); if (imageTensorFile.is_open()) { mapbox::util::apply_visitor( [&imageTensorFile](auto&& imageData){ WriteImageTensorImpl(imageData,imageTensorFile); }, imageDataContainers[0] ); if (!imageTensorFile) { ARMNN_LOG(fatal) << "Failed to write to output file" << outputPath; imageTensorFile.close(); return -1; } imageTensorFile.close(); } else { ARMNN_LOG(fatal) << "Failed to open output file" << outputPath; return -1; } return 0; }