Diffstat (limited to 'tests/ImageTensorGenerator/ImageTensorGenerator.cpp')
-rw-r--r--  tests/ImageTensorGenerator/ImageTensorGenerator.cpp  123
1 file changed, 105 insertions, 18 deletions
diff --git a/tests/ImageTensorGenerator/ImageTensorGenerator.cpp b/tests/ImageTensorGenerator/ImageTensorGenerator.cpp
index 1f537745b4..f391a27a4d 100644
--- a/tests/ImageTensorGenerator/ImageTensorGenerator.cpp
+++ b/tests/ImageTensorGenerator/ImageTensorGenerator.cpp
@@ -3,13 +3,16 @@
// SPDX-License-Identifier: MIT
//
+#include "ImageTensorGenerator.hpp"
#include "../InferenceTestImage.hpp"
+#include <armnn/TypesUtils.hpp>
#include <boost/filesystem.hpp>
#include <boost/filesystem/operations.hpp>
#include <boost/filesystem/path.hpp>
#include <boost/log/trivial.hpp>
#include <boost/program_options.hpp>
+#include <boost/variant.hpp>
#include <algorithm>
#include <fstream>
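
The hunks below rely on several symbols (SupportedFrontend, NormalizationParameters, GetNormalizationParameters, PrepareImageTensor, WriteImageTensorImpl) that come from the newly included ImageTensorGenerator.hpp, which this diff does not show. The following is a minimal sketch of the declarations that header is assumed to provide, inferred only from how the symbols are used further down; member names, defaults and comments are illustrative rather than the authoritative header.

// Sketch (assumption): declarations expected from ImageTensorGenerator.hpp.
#include <armnn/Types.hpp>
#include <array>
#include <fstream>
#include <string>
#include <vector>

enum class SupportedFrontend
{
    Caffe,
    TensorFlow,
    TFLite,
};

// Per-frontend / per-output-type normalization constants (illustrative defaults).
struct NormalizationParameters
{
    float scale{ 1.0f };
    std::array<float, 3> mean{ { 0.0f, 0.0f, 0.0f } };
    std::array<float, 3> stddev{ { 1.0f, 1.0f, 1.0f } };
};

NormalizationParameters GetNormalizationParameters(const SupportedFrontend& modelFormat,
                                                   const armnn::DataType& outputType);

// Load an image, resize it to newWidth x newHeight, then return the pixel data
// normalized and arranged according to outputLayout.
template <typename ElemType>
std::vector<ElemType> PrepareImageTensor(const std::string& imagePath,
                                         unsigned int newWidth,
                                         unsigned int newHeight,
                                         const NormalizationParameters& normParams,
                                         unsigned int batchSize,
                                         const armnn::DataLayout& outputLayout);

// Write the tensor elements to the given output file stream.
template <typename ElemType>
void WriteImageTensorImpl(const std::vector<ElemType>& imageData, std::ofstream& imageTensorFile);
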
@@ -57,10 +60,7 @@ public:
return false;
}
- std::vector<std::string> supportedLayouts = {
- "NHWC",
- "NCHW"
- };
+ std::vector<std::string> supportedLayouts = { "NHWC", "NCHW" };
auto iterator = std::find(supportedLayouts.begin(), supportedLayouts.end(), layout);
if (iterator == supportedLayouts.end())
@@ -113,10 +113,20 @@ public:
("help,h", "Display help messages")
("infile,i", po::value<std::string>(&m_InputFileName)->required(),
"Input image file to generate tensor from")
- ("layout,l", po::value<std::string>(&m_Layout)->default_value("NHWC"),
- "Output data layout, \"NHWC\" or \"NCHW\", default value NHWC")
+ ("model-format,f", po::value<std::string>(&m_ModelFormat)->required(),
+ "Format of the model file, Accepted values (caffe, tensorflow, tflite)")
("outfile,o", po::value<std::string>(&m_OutputFileName)->required(),
- "Output raw tensor file path");
+ "Output raw tensor file path")
+ ("output-type,z", po::value<std::string>(&m_OutputType)->default_value("float"),
+ "The data type of the output tensors."
+ "If unset, defaults to \"float\" for all defined inputs. "
+ "Accepted values (float, int or qasymm8)")
+ ("new-width,w", po::value<std::string>(&m_NewWidth)->default_value("0"),
+ "Resize image to new width. Keep original width if unspecified")
+ ("new-height,h", po::value<std::string>(&m_NewHeight)->default_value("0"),
+ "Resize image to new height. Keep original height if unspecified")
+ ("layout,l", po::value<std::string>(&m_Layout)->default_value("NHWC"),
+ "Output data layout, \"NHWC\" or \"NCHW\", default value NHWC");
}
catch (const std::exception& e)
{
@@ -164,13 +174,71 @@ public:
}
std::string GetInputFileName() {return m_InputFileName;}
- std::string GetLayout() {return m_Layout;}
+ armnn::DataLayout GetLayout()
+ {
+ if (m_Layout == "NHWC")
+ {
+ return armnn::DataLayout::NHWC;
+ }
+ else if (m_Layout == "NCHW")
+ {
+ return armnn::DataLayout::NCHW;
+ }
+ else
+ {
+ throw armnn::Exception("Unsupported data layout: " + m_Layout);
+ }
+ }
std::string GetOutputFileName() {return m_OutputFileName;}
+ unsigned int GetNewWidth() {return static_cast<unsigned int>(std::stoi(m_NewWidth));}
+ unsigned int GetNewHeight() {return static_cast<unsigned int>(std::stoi(m_NewHeight));}
+ SupportedFrontend GetModelFormat()
+ {
+ if (m_ModelFormat == "caffe")
+ {
+ return SupportedFrontend::Caffe;
+ }
+ else if (m_ModelFormat == "tensorflow")
+ {
+ return SupportedFrontend::TensorFlow;
+ }
+ else if (m_ModelFormat == "tflite")
+ {
+ return SupportedFrontend::TFLite;
+ }
+ else
+ {
+ throw armnn::Exception("Unsupported model format" + m_ModelFormat);
+ }
+ }
+ armnn::DataType GetOutputType()
+ {
+ if (m_OutputType == "float")
+ {
+ return armnn::DataType::Float32;
+ }
+ else if (m_OutputType == "int")
+ {
+ return armnn::DataType::Signed32;
+ }
+ else if (m_OutputType == "qasymm8")
+ {
+ return armnn::DataType::QuantisedAsymm8;
+ }
+ else
+ {
+ throw armnn::Exception("Unsupported input type" + m_OutputType);
+ }
+ }
private:
std::string m_InputFileName;
std::string m_Layout;
std::string m_OutputFileName;
+ std::string m_NewWidth;
+ std::string m_NewHeight;
+ std::string m_ModelFormat;
+ std::string m_OutputType;
};
} // namespace anonymous
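
GetLayout, GetModelFormat and GetOutputType above each map a user-supplied string onto an enum through an if/else chain. A table-driven lookup is an equivalent formulation that keeps the accepted values in one place; a sketch under that assumption, where ParseOutputType and the map contents are illustrative and not part of this change:

// Sketch (assumption): table-driven equivalent of the --output-type mapping.
#include <armnn/Exceptions.hpp>
#include <armnn/Types.hpp>
#include <map>
#include <string>

armnn::DataType ParseOutputType(const std::string& outputType)
{
    static const std::map<std::string, armnn::DataType> supportedTypes =
    {
        { "float",   armnn::DataType::Float32 },
        { "int",     armnn::DataType::Signed32 },
        { "qasymm8", armnn::DataType::QuantisedAsymm8 }
    };

    const auto match = supportedTypes.find(outputType);
    if (match == supportedTypes.end())
    {
        throw armnn::Exception("Unsupported output type: " + outputType);
    }
    return match->second;
}
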
@@ -182,18 +250,36 @@ int main(int argc, char* argv[])
{
return -1;
}
-
const std::string imagePath(cmdline.GetInputFileName());
const std::string outputPath(cmdline.GetOutputFileName());
-
- // generate image tensor
- std::vector<float> imageData;
+ const SupportedFrontend& modelFormat(cmdline.GetModelFormat());
+ const armnn::DataType outputType(cmdline.GetOutputType());
+ const unsigned int newWidth = cmdline.GetNewWidth();
+ const unsigned int newHeight = cmdline.GetNewHeight();
+ const unsigned int batchSize = 1;
+ const armnn::DataLayout outputLayout(cmdline.GetLayout());
+
+ using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<uint8_t>>;
+ std::vector<TContainer> imageDataContainers;
+ const NormalizationParameters& normParams = GetNormalizationParameters(modelFormat, outputType);
try
{
- InferenceTestImage testImage(imagePath.c_str());
- imageData = cmdline.GetLayout() == "NHWC"
- ? GetImageDataAsNormalizedFloats(ImageChannelLayout::Rgb, testImage)
- : GetImageDataInArmNnLayoutAsNormalizedFloats(ImageChannelLayout::Rgb, testImage);
+ switch (outputType)
+ {
+ case armnn::DataType::Signed32:
+ imageDataContainers.push_back(PrepareImageTensor<int>(
+ imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
+ break;
+ case armnn::DataType::QuantisedAsymm8:
+ imageDataContainers.push_back(PrepareImageTensor<uint8_t>(
+ imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
+ break;
+ case armnn::DataType::Float32:
+ default:
+ imageDataContainers.push_back(PrepareImageTensor<float>(
+ imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
+ break;
+ }
}
catch (const InferenceTestImageException& e)
{
@@ -205,7 +291,8 @@ int main(int argc, char* argv[])
imageTensorFile.open(outputPath, std::ofstream::out);
if (imageTensorFile.is_open())
{
- std::copy(imageData.begin(), imageData.end(), std::ostream_iterator<float>(imageTensorFile, " "));
+ boost::apply_visitor([&imageTensorFile](auto&& imageData) { WriteImageTensorImpl(imageData, imageTensorFile); },
+ imageDataContainers[0]);
if (!imageTensorFile)
{
BOOST_LOG_TRIVIAL(fatal) << "Failed to write to output file: " << outputPath;
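
The plain std::copy over a float vector is replaced here by boost::apply_visitor, which dispatches a single generic lambda on whichever element type the variant currently holds. A self-contained sketch of that pattern, assuming WriteImageTensorImpl simply streams each value separated by a space (WriteTensorSketch is an illustrative stand-in, not the real helper):

// Sketch (assumption): generic visitation over the three tensor element types.
#include <boost/variant.hpp>
#include <cstdint>
#include <iostream>
#include <vector>

using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<uint8_t>>;

// Assumed behaviour of WriteImageTensorImpl: stream each value, space separated.
// Unary plus promotes uint8_t to int so it prints as a number, not a character.
template <typename T>
void WriteTensorSketch(const std::vector<T>& data, std::ostream& out)
{
    for (const T& value : data)
    {
        out << +value << " ";
    }
}

int main()
{
    TContainer container = std::vector<uint8_t>{ 0, 127, 255 };
    boost::apply_visitor([](const auto& data) { WriteTensorSketch(data, std::cout); }, container);
    return 0;
}
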
@@ -221,4 +308,4 @@ int main(int argc, char* argv[])
}
return 0;
-}
\ No newline at end of file
+}