Diffstat (limited to 'tests/ModelAccuracyTool-Armnn/ModelAccuracyTool-Armnn.cpp')
 -rw-r--r--  tests/ModelAccuracyTool-Armnn/ModelAccuracyTool-Armnn.cpp | 108
 1 file changed, 43 insertions(+), 65 deletions(-)
diff --git a/tests/ModelAccuracyTool-Armnn/ModelAccuracyTool-Armnn.cpp b/tests/ModelAccuracyTool-Armnn/ModelAccuracyTool-Armnn.cpp
index 3abfb3c2ec..aec4d70271 100644
--- a/tests/ModelAccuracyTool-Armnn/ModelAccuracyTool-Armnn.cpp
+++ b/tests/ModelAccuracyTool-Armnn/ModelAccuracyTool-Armnn.cpp
@@ -4,9 +4,9 @@
//
#include "ModelAccuracyChecker.hpp"
-#include "../InferenceTest.hpp"
#include "../ImagePreprocessor.hpp"
#include "armnnDeserializer/IDeserializer.hpp"
+#include "../NetworkExecutionUtils/NetworkExecutionUtils.hpp"
#include <boost/filesystem.hpp>
#include <boost/range/iterator_range.hpp>
@@ -14,70 +14,8 @@
using namespace armnn::test;
-namespace po = boost::program_options;
-
-bool CheckOption(const po::variables_map& vm,
- const char* option)
-{
- // Check that the given option is valid.
- if (option == nullptr)
- {
- return false;
- }
-
- // Check whether 'option' is provided.
- return vm.find(option) != vm.end();
-}
-
-template<typename T, typename TParseElementFunc>
-std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:")
-{
- std::vector<T> result;
- // Processes line-by-line.
- std::string line;
- while (std::getline(stream, line))
- {
- std::vector<std::string> tokens;
- try
- {
- // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
- boost::split(tokens, line, boost::algorithm::is_any_of(chars), boost::token_compress_on);
- }
- catch (const std::exception& e)
- {
- BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
- continue;
- }
- for (const std::string& token : tokens)
- {
- if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
- {
- try
- {
- result.push_back(parseElementFunc(token));
- }
- catch (const std::exception&)
- {
- BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
- }
- }
- }
- }
-
- return result;
-}
-
map<std::string, int> LoadValidationLabels(const string & validationLabelPath);
-template<armnn::DataType NonQuantizedType>
-auto ParseDataArray(std::istream & stream);
-
-template<>
-auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream)
-{
- return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
-}
-
int main(int argc, char* argv[])
{
try
@@ -94,6 +32,7 @@ int main(int argc, char* argv[])
std::vector<armnn::BackendId> defaultBackends = {armnn::Compute::CpuAcc, armnn::Compute::CpuRef};
std::string modelPath;
std::string dataDir;
+ std::string inputType = "float";
std::string inputName;
std::string outputName;
std::string validationLabelPath;
@@ -112,6 +51,9 @@ int main(int argc, char* argv[])
backendsMessage.c_str())
("data-dir,d", po::value<std::string>(&dataDir)->required(),
"Path to directory containing the ImageNet test data")
+ ("input-type,y", po::value(&inputType), "The data type of the input tensors."
+ "If unset, defaults to \"float\" for all defined inputs. "
+ "Accepted values (float, int or qasymm8)")
("input-name,i", po::value<std::string>(&inputName)->required(),
"Identifier of the input tensors in the network separated by comma.")
("output-name,o", po::value<std::string>(&outputName)->required(),
@@ -217,14 +159,50 @@ int main(int argc, char* argv[])
if(ValidateDirectory(dataDir))
{
+ InferenceModel<armnnDeserializer::IDeserializer, float>::Params params;
+ params.m_ModelPath = modelPath;
+ params.m_IsModelBinary = true;
+ params.m_ComputeDevices = computeDevice;
+ params.m_InputBindings.push_back(inputName);
+ params.m_OutputBindings.push_back(outputName);
+
+ using TParser = armnnDeserializer::IDeserializer;
+ InferenceModel<TParser, float> model(params, false);
for (auto & imageEntry : boost::make_iterator_range(directory_iterator(pathToDataDir), {}))
{
cout << "Processing image: " << imageEntry << "\n";
std::ifstream inputTensorFile(imageEntry.path().string());
vector<TContainer> inputDataContainers;
- inputDataContainers.push_back(ParseDataArray<armnn::DataType::Float32>(inputTensorFile));
- vector<TContainer> outputDataContainers = {vector<float>(1001)};
+ vector<TContainer> outputDataContainers;
+
+ if (inputType.compare("float") == 0)
+ {
+ inputDataContainers.push_back(
+ ParseDataArray<armnn::DataType::Float32>(inputTensorFile));
+ outputDataContainers = {vector<float>(1001)};
+ }
+ else if (inputType.compare("int") == 0)
+ {
+ inputDataContainers.push_back(
+ ParseDataArray<armnn::DataType::Signed32>(inputTensorFile));
+ outputDataContainers = {vector<int>(1001)};
+ }
+ else if (inputType.compare("qasymm8") == 0)
+ {
+ auto inputBinding = model.GetInputBindingInfo();
+ inputDataContainers.push_back(
+ ParseDataArray<armnn::DataType::QuantisedAsymm8>(
+ inputTensorFile,
+ inputBinding.second.GetQuantizationScale(),
+ inputBinding.second.GetQuantizationOffset()));
+ outputDataContainers = {vector<uint8_t>(1001)};
+ }
+ else
+ {
+ BOOST_LOG_TRIVIAL(fatal) << "Unsupported tensor data type \"" << inputType << "\". ";
+ return EXIT_FAILURE;
+ }
status = runtime->EnqueueWorkload(networkId,
armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
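
Note on the qasymm8 branch above: it fetches the input binding's quantization scale and offset and forwards them to ParseDataArray, so the floating-point values read from the tensor file can be quantized on the fly. Below is a minimal, hypothetical sketch of what such a helper could look like; the real overload is declared in NetworkExecutionUtils.hpp and may differ in detail. The affine mapping q = round(value / scale) + offset, clamped to [0, 255], is an assumption here, not copied from the patch.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <istream>
#include <string>
#include <vector>

// Hypothetical helper, not the NetworkExecutionUtils.hpp implementation:
// reads whitespace-separated floating-point values and quantizes them to QAsymm8.
std::vector<uint8_t> ParseQAsymm8Data(std::istream& stream, float scale, int32_t offset)
{
    std::vector<uint8_t> result;
    std::string token;
    while (stream >> token)
    {
        const float value = std::stof(token);
        // Affine quantization: q = round(value / scale) + offset, clamped to the uint8 range.
        int32_t q = static_cast<int32_t>(std::round(value / scale)) + offset;
        q = std::min(255, std::max(0, q));
        result.push_back(static_cast<uint8_t>(q));
    }
    return result;
}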