aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorColm Donelan <colm.donelan@arm.com>2023-06-22 10:19:17 +0100
committerColm Donelan <colm.donelan@arm.com>2023-06-30 14:20:56 +0100
commit0dfb2658ce521571aa0f9e859f813c60fda9d8d6 (patch)
treeef4605cebb29c6754da8dfea0decd4250288e2a5
parent16e27cf81424dcad05d129f9ba368a8c446cd25f (diff)
downloadarmnn-0dfb2658ce521571aa0f9e859f813c60fda9d8d6.tar.gz
IVGCVSW-7666 Add a FileComparisonExecutor to ExecuteNetwork.
* Implement the "-C" command line option of executenetwork. * Add a FileComparisonExecutorFile which will read tensors from a previously written text file and compare them to the execution output. Signed-off-by: Colm Donelan <colm.donelan@arm.com> Change-Id: I8380fd263028af13d65a67fb6afd89626d1b07b8
-rw-r--r--CMakeLists.txt13
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp8
-rw-r--r--src/armnnTfLiteParser/test/TfLiteParser.cpp18
-rw-r--r--tests/CMakeLists.txt11
-rw-r--r--tests/ExecuteNetwork/ArmNNExecutor.cpp3
-rw-r--r--tests/ExecuteNetwork/ExecuteNetwork.cpp18
-rw-r--r--tests/ExecuteNetwork/ExecuteNetworkProgramOptions.cpp4
-rw-r--r--tests/ExecuteNetwork/FileComparisonExecutor.cpp344
-rw-r--r--tests/ExecuteNetwork/FileComparisonExecutor.hpp27
-rw-r--r--tests/ExecuteNetwork/TfliteExecutor.cpp50
-rw-r--r--tests/ExecuteNetwork/test/FileComparisonExecutorTests.cpp74
-rw-r--r--tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp14
12 files changed, 557 insertions, 27 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6e53d45f81..a4199fe2d4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -848,6 +848,9 @@ if(BUILD_UNIT_TESTS)
if(BUILD_TESTS)
list(APPEND unittest_sources
+ ./tests/ExecuteNetwork/FileComparisonExecutor.hpp
+ ./tests/ExecuteNetwork/FileComparisonExecutor.cpp
+ ./tests/ExecuteNetwork/test/FileComparisonExecutorTests.cpp
./tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
./tests/NetworkExecutionUtils/test/NetworkExecutionUtilsTests.cpp)
endif()
@@ -858,11 +861,11 @@ if(BUILD_UNIT_TESTS)
endforeach()
add_executable(UnitTests ${unittest_sources})
- target_include_directories(UnitTests PRIVATE src/armnn)
- target_include_directories(UnitTests PRIVATE src/armnnUtils)
- target_include_directories(UnitTests PRIVATE src/armnnTestUtils)
- target_include_directories(UnitTests PRIVATE src/backends)
- target_include_directories(UnitTests PRIVATE src/profiling)
+ target_include_directories(UnitTests PRIVATE delegate/common/include
+ src/armnn src/armnnUtils
+ src/armnnTestUtils src/backends
+ src/profiling
+ tests)
if(VALGRIND_FOUND)
if(HEAP_PROFILING OR LEAK_CHECKING)
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 199a853918..1c5b4fc9f0 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -4876,9 +4876,15 @@ TfLiteParserImpl::ModelPtr TfLiteParserImpl::LoadModelFromFile(const char* fileN
std::stringstream msg;
msg << "Cannot find the file (" << fileName << ") errorCode: " << errorCode
<< " " << CHECK_LOCATION().AsString();
-
throw FileNotFoundException(msg.str());
}
+ if (!fs::is_regular_file(pathToFile))
+ {
+ // Exclude non regular files.
+ throw InvalidArgumentException(fmt::format("File \"{}\" is not a regular file and cannot be loaded.",
+ pathToFile.c_str()));
+ }
+
std::ifstream file(fileName, std::ios::binary);
std::string fileContent((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
return LoadModelFromBinary(reinterpret_cast<const uint8_t *>(fileContent.c_str()),
diff --git a/src/armnnTfLiteParser/test/TfLiteParser.cpp b/src/armnnTfLiteParser/test/TfLiteParser.cpp
index 65bbaeae0f..841c46e620 100644
--- a/src/armnnTfLiteParser/test/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/test/TfLiteParser.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -37,4 +37,20 @@ TEST_CASE_FIXTURE(NoInputBindingsFixture, "ParseBadInputBindings")
CHECK_THROWS_AS((RunTest<4, armnn::DataType::QAsymmU8>(0, { }, { 0 })), armnn::ParseException);
}
+TEST_CASE("ParseInvalidFileName")
+{
+ // Nullptr should throw InvalidArgumentException
+ CHECK_THROWS_AS(armnnTfLiteParser::TfLiteParserImpl::LoadModelFromFile(nullptr), armnn::InvalidArgumentException);
+ // Empty string should throw FileNotFoundException.
+ CHECK_THROWS_AS(armnnTfLiteParser::TfLiteParserImpl::LoadModelFromFile(""), armnn::FileNotFoundException);
+ // Garbage string should throw FileNotFoundException.
+ CHECK_THROWS_AS(armnnTfLiteParser::TfLiteParserImpl::LoadModelFromFile("askjfhseuirwqeuiy"),
+ armnn::FileNotFoundException);
+ // Valid directory should throw InvalidArgumentException
+ CHECK_THROWS_AS(armnnTfLiteParser::TfLiteParserImpl::LoadModelFromFile("."), armnn::InvalidArgumentException);
+ // Valid file but not a regular file should throw InvalidArgumentException
+ CHECK_THROWS_AS(armnnTfLiteParser::TfLiteParserImpl::LoadModelFromFile("/dev/null"),
+ armnn::InvalidArgumentException);
+}
+
}
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 71374c4261..eca03b1bd2 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -152,6 +152,8 @@ if (BUILD_ARMNN_SERIALIZER
ExecuteNetwork/ExecuteNetworkProgramOptions.hpp
ExecuteNetwork/ExecuteNetworkParams.cpp
ExecuteNetwork/ExecuteNetworkParams.hpp
+ ExecuteNetwork/FileComparisonExecutor.cpp
+ ExecuteNetwork/FileComparisonExecutor.hpp
NetworkExecutionUtils/NetworkExecutionUtils.cpp
NetworkExecutionUtils/NetworkExecutionUtils.hpp)
@@ -164,10 +166,11 @@ if (BUILD_ARMNN_SERIALIZER
endif()
add_executable_ex(ExecuteNetwork ${ExecuteNetwork_sources})
- target_include_directories(ExecuteNetwork PRIVATE ../src/armnn)
- target_include_directories(ExecuteNetwork PRIVATE ../src/armnnUtils)
- target_include_directories(ExecuteNetwork PRIVATE ../src/backends)
- target_include_directories(ExecuteNetwork PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
+ target_include_directories(ExecuteNetwork PRIVATE ../src/armnn
+ ../src/armnnUtils
+ ../src/backends
+ ./NetworkExecutionUtils
+ ${CMAKE_CURRENT_SOURCE_DIR})
if(EXECUTE_NETWORK_STATIC)
target_link_libraries(ExecuteNetwork
diff --git a/tests/ExecuteNetwork/ArmNNExecutor.cpp b/tests/ExecuteNetwork/ArmNNExecutor.cpp
index 6955a492b3..4881dff7fb 100644
--- a/tests/ExecuteNetwork/ArmNNExecutor.cpp
+++ b/tests/ExecuteNetwork/ArmNNExecutor.cpp
@@ -663,7 +663,8 @@ void ArmNNExecutor::PrintOutputTensors(const armnn::OutputTensors* outputTensors
outputTensorFile,
bindingName,
output.second,
- !m_Params.m_DontPrintOutputs
+ !m_Params.m_DontPrintOutputs,
+ output.second.GetDataType()
};
std::cout << bindingName << ": ";
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index f9f583a9c6..9f81eb1168 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -3,14 +3,14 @@
// SPDX-License-Identifier: MIT
//
-#include "ExecuteNetworkProgramOptions.hpp"
#include "ArmNNExecutor.hpp"
+#include "ExecuteNetworkProgramOptions.hpp"
#if defined(ARMNN_TFLITE_DELEGATE) || defined(ARMNN_TFLITE_OPAQUE_DELEGATE)
#include "TfliteExecutor.hpp"
#endif
+#include "FileComparisonExecutor.hpp"
#include <armnn/Logging.hpp>
-
std::unique_ptr<IExecutor> BuildExecutor(ProgramOptions& programOptions)
{
if (programOptions.m_ExNetParams.m_TfLiteExecutor ==
@@ -42,7 +42,6 @@ int main(int argc, const char* argv[])
#endif
armnn::ConfigureLogging(true, true, level);
-
// Get ExecuteNetwork parameters and runtime options from command line
// This might throw an InvalidArgumentException if the user provided invalid inputs
ProgramOptions programOptions;
@@ -72,15 +71,14 @@ int main(int argc, const char* argv[])
return EXIT_FAILURE;
}
-
executor->PrintNetworkInfo();
outputResults = executor->Execute();
if (!programOptions.m_ExNetParams.m_ComparisonComputeDevices.empty() ||
- programOptions.m_ExNetParams.m_CompareWithTflite)
+ programOptions.m_ExNetParams.m_CompareWithTflite)
{
ExecuteNetworkParams comparisonParams = programOptions.m_ExNetParams;
- comparisonParams.m_ComputeDevices = programOptions.m_ExNetParams.m_ComparisonComputeDevices;
+ comparisonParams.m_ComputeDevices = programOptions.m_ExNetParams.m_ComparisonComputeDevices;
if (programOptions.m_ExNetParams.m_CompareWithTflite)
{
@@ -99,4 +97,12 @@ int main(int argc, const char* argv[])
comparisonExecutor->CompareAndPrintResult(outputResults);
}
+
+ // If there's a file comparison specified create a FileComparisonExecutor.
+ if (!programOptions.m_ExNetParams.m_ComparisonFile.empty())
+ {
+ FileComparisonExecutor comparisonExecutor(programOptions.m_ExNetParams);
+ comparisonExecutor.Execute();
+ comparisonExecutor.CompareAndPrintResult(outputResults);
+ }
}
diff --git a/tests/ExecuteNetwork/ExecuteNetworkProgramOptions.cpp b/tests/ExecuteNetwork/ExecuteNetworkProgramOptions.cpp
index 8d5035e103..7b55b28b8b 100644
--- a/tests/ExecuteNetwork/ExecuteNetworkProgramOptions.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetworkProgramOptions.cpp
@@ -361,8 +361,8 @@ ProgramOptions::ProgramOptions() : m_CxxOptions{"ExecuteNetwork",
("C, compare-output",
"Perform a per byte root mean square error calculation of the inference output with an output"
- " file that has been previously produced by running a network through ExecuteNetwork."
- " See --write-outputs-to-file to produce an output file for an execution.",
+ " file(s) that has been previously produced by running a network through ExecuteNetwork."
+ " See --write-outputs-to-file to produce an output file(s) for an execution.",
cxxopts::value<std::string>(m_ExNetParams.m_ComparisonFile))
("B, compare-output-with-backend",
diff --git a/tests/ExecuteNetwork/FileComparisonExecutor.cpp b/tests/ExecuteNetwork/FileComparisonExecutor.cpp
new file mode 100644
index 0000000000..1675440e8c
--- /dev/null
+++ b/tests/ExecuteNetwork/FileComparisonExecutor.cpp
@@ -0,0 +1,344 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "FileComparisonExecutor.hpp"
+#include <NetworkExecutionUtils/NetworkExecutionUtils.hpp>
+#include <algorithm>
+#include <filesystem>
+#include <iterator>
+
+using namespace armnn;
+
+/**
+ * Given a buffer in the expected format. Extract from it the tensor name, tensor type as strings and return an
+ * index pointing to the start of the data section.
+ *
+ * @param buffer data to be parsed.
+ * @param tensorName the name of the tensor extracted from the header.
+ * @param tensorType the type of the tensor extracted from the header.
+ * @return index pointing to the start of the data in the buffer.
+ */
+unsigned int ExtractHeader(const std::vector<char>& buffer, std::string& tensorName, DataType& tensorType)
+{
+ auto isColon = [](char c) { return c == ':'; };
+ auto isComma = [](char c) { return c == ','; };
+
+ // Find the "," separator marks the end of the tensor name.
+ auto firstComma = std::find_if(buffer.begin(), buffer.end(), isComma);
+ if (firstComma == buffer.end())
+ {
+ throw ParseException("Unable to read tensor name from file.");
+ }
+ tensorName.assign(buffer.begin(), firstComma);
+
+ // The next colon marks the end of the data type string.
+ auto endOfHeader = std::find_if(firstComma, buffer.end(), isColon);
+ if (firstComma == buffer.end())
+ {
+ throw ParseException("Unable to read tensor type from file.");
+ }
+ std::string type(++firstComma, endOfHeader);
+ // Remove any leading or trailing whitespace.
+ type.erase(remove_if(type.begin(), type.end(), isspace), type.end());
+ if (type == "Float16")
+ {
+ tensorType = DataType::Float16;
+ }
+ else if (type == "Float32")
+ {
+ tensorType = DataType::Float32;
+ }
+ else if (type == "QAsymmU8")
+ {
+ tensorType = DataType::QAsymmU8;
+ }
+ else if (type == "Signed32")
+ {
+ tensorType = DataType::Signed32;
+ }
+ else if (type == "Boolean")
+ {
+ tensorType = DataType::Boolean;
+ }
+ else if (type == "QSymmS16")
+ {
+ tensorType = DataType::QSymmS16;
+ }
+ else if (type == "QSymmS8")
+ {
+ tensorType = DataType::QSymmS8;
+ }
+ else if (type == "QAsymmS8")
+ {
+ tensorType = DataType::QAsymmS8;
+ }
+ else if (type == "BFloat16")
+ {
+ tensorType = DataType::BFloat16;
+ }
+ else if (type == "Signed64")
+ {
+ tensorType = DataType::Signed64;
+ }
+ else
+ {
+ throw ParseException("Invalid data type in header.");
+ }
+ // Remember to move the iterator past the colon.
+ return (++endOfHeader - buffer.begin());
+}
+
+/**
+ * Extract the data from the file and return as a typed vector of elements.
+ *
+ * @param buffer data to be parsed.
+ * @param dataStart Index into the vector where the tensor data starts.
+ * @param tensorType the type of the tensor extracted from the header.
+ */
+template <typename T>
+void ReadData(const std::vector<char>& buffer,
+ const unsigned int dataStart,
+ const DataType& tensorType,
+ std::vector<T>& results)
+{
+ unsigned int index = dataStart;
+ while (index < buffer.size())
+ {
+ std::string elementString;
+ // Extract into a string until the next space.
+ while (index < buffer.size() && buffer[index] != ' ')
+ {
+ elementString.push_back(buffer[index]);
+ index++;
+ }
+ if (!elementString.empty())
+ {
+ switch (tensorType)
+ {
+ case DataType::Float32: {
+ results.push_back(std::stof(elementString));
+ break;
+ }
+
+ case DataType::Signed32: {
+ results.push_back(std::stoi(elementString));
+ break;
+ }
+ case DataType::QSymmS8:
+ case DataType::QAsymmS8: {
+ results.push_back(elementString[0]);
+ break;
+ }
+ case DataType::QAsymmU8: {
+ results.push_back(elementString[0]);
+ break;
+ }
+ case DataType::Float16:
+ case DataType::QSymmS16:
+ case DataType::BFloat16:
+ case DataType::Boolean:
+ case DataType::Signed64:
+ default: {
+ LogAndThrow("Unsupported DataType");
+ }
+ }
+ // Finally, skip the space we know is there.
+ index++;
+ }
+ else
+ {
+ if (index < buffer.size())
+ {
+ index++;
+ }
+ }
+ }
+}
+
+/**
+ * Open the given file and read the data out of it to construct a Tensor. This could throw FileNotFoundException
+ * or InvalidArgumentException
+ *
+ * @param fileName the file to be read.
+ * @return a populated tensor.
+ */
+Tensor ReadTensorFromFile(const std::string fileName)
+{
+ if (!std::filesystem::exists(fileName))
+ {
+ throw FileNotFoundException("The file \"" + fileName + "\" could not be found.");
+ }
+ // The format we are reading in is based on NetworkExecutionUtils::WriteToFile. This could potentially
+ // be an enormous tensor. We'll limit what we can read in to 1Mb.
+ std::uintmax_t maxFileSize = 1048576;
+ std::uintmax_t fileSize = std::filesystem::file_size(fileName);
+ if (fileSize > maxFileSize)
+ {
+ throw InvalidArgumentException("The file \"" + fileName + "\" exceeds max size of 1 Mb.");
+ }
+
+ // We'll read the entire file into one buffer.
+ std::ifstream file(fileName, std::ios::binary);
+ std::vector<char> buffer(fileSize);
+ if (file.read(buffer.data(), fileSize))
+ {
+ std::string tensorName;
+ DataType tensorType;
+ unsigned int tensorDataStart = ExtractHeader(buffer, tensorName, tensorType);
+ switch (tensorType)
+ {
+ case DataType::Float32: {
+ std::vector<float> floatVector;
+ ReadData(buffer, tensorDataStart, tensorType, floatVector);
+ TensorInfo info({ static_cast<unsigned int>(floatVector.size()), 1, 1, 1 }, DataType::Float32);
+ float* floats = new float[floatVector.size()];
+ memcpy(floats, floatVector.data(), (floatVector.size() * sizeof(float)));
+ return Tensor(info, floats);
+ }
+ case DataType::Signed32: {
+ std::vector<int> intVector;
+ ReadData(buffer, tensorDataStart, tensorType, intVector);
+ TensorInfo info({ static_cast<unsigned int>(intVector.size()), 1, 1, 1 }, DataType::Signed32);
+ int* ints = new int[intVector.size()];
+ memcpy(ints, intVector.data(), (intVector.size() * sizeof(float)));
+ return Tensor(info, ints);
+ }
+ case DataType::QSymmS8: {
+ std::vector<int8_t> intVector;
+ ReadData(buffer, tensorDataStart, tensorType, intVector);
+ TensorInfo info({ static_cast<unsigned int>(intVector.size()), 1, 1, 1 }, DataType::QSymmS8);
+ int8_t* ints = new int8_t[intVector.size()];
+ memcpy(ints, intVector.data(), (intVector.size() * sizeof(float)));
+ return Tensor(info, ints);
+ }
+ case DataType::QAsymmS8: {
+ std::vector<int8_t> intVector;
+ ReadData(buffer, tensorDataStart, tensorType, intVector);
+ TensorInfo info({ static_cast<unsigned int>(intVector.size()), 1, 1, 1 }, DataType::QAsymmS8);
+ int8_t* ints = new int8_t[intVector.size()];
+ memcpy(ints, intVector.data(), (intVector.size() * sizeof(float)));
+ return Tensor(info, ints);
+ }
+ case DataType::QAsymmU8: {
+ std::vector<uint8_t> intVector;
+ ReadData(buffer, tensorDataStart, tensorType, intVector);
+ TensorInfo info({ static_cast<unsigned int>(intVector.size()), 1, 1, 1 }, DataType::QAsymmU8);
+ uint8_t* ints = new uint8_t[intVector.size()];
+ memcpy(ints, intVector.data(), (intVector.size() * sizeof(float)));
+ return Tensor(info, ints);
+ }
+ default:
+ throw InvalidArgumentException("The tensor data could not be read from \"" + fileName + "\"");
+ }
+ }
+ else
+ {
+ throw ParseException("Filed to read the contents of \"" + fileName + "\"");
+ }
+
+ Tensor result;
+ return result;
+}
+
+FileComparisonExecutor::FileComparisonExecutor(const ExecuteNetworkParams& params)
+ : m_Params(params)
+{}
+
+std::vector<const void*> FileComparisonExecutor::Execute()
+{
+ std::string filesToCompare = this->m_Params.m_ComparisonFile;
+ if (filesToCompare.empty())
+ {
+ throw InvalidArgumentException("The file(s) to compare was not set.");
+ }
+ // filesToCompare is one or more files containing output tensors. Iterate and read in the tensors.
+ // We'll assume the string follows the same comma seperated format as write-outputs-to-file.
+ std::stringstream ss(filesToCompare);
+ std::vector<std::string> fileNames;
+ std::string errorString;
+ while (ss.good())
+ {
+ std::string substr;
+ getline(ss, substr, ',');
+ // Check the file exist.
+ if (!std::filesystem::exists(substr))
+ {
+ errorString += substr + " ";
+ }
+ else
+ {
+ fileNames.push_back(substr);
+ }
+ }
+ if (!errorString.empty())
+ {
+ throw FileNotFoundException("The following file(s) to compare could not be found: " + errorString);
+ }
+ // Read in the tensors into m_OutputTensorsVec
+ OutputTensors outputs;
+ std::vector<const void*> results;
+ for (auto file : fileNames)
+ {
+ Tensor t = ReadTensorFromFile(file);
+ outputs.push_back({ 0, Tensor(t.GetInfo(), t.GetMemoryArea()) });
+ results.push_back(t.GetMemoryArea());
+ }
+ m_OutputTensorsVec.push_back(outputs);
+ return results;
+}
+
+void FileComparisonExecutor::PrintNetworkInfo()
+{
+ std::cout << "Not implemented in this class." << std::endl;
+}
+
+void FileComparisonExecutor::CompareAndPrintResult(std::vector<const void*> otherOutput)
+{
+ unsigned int index = 0;
+ std::string typeString;
+ for (const auto& outputTensors : m_OutputTensorsVec)
+ {
+ for (const auto& outputTensor : outputTensors)
+ {
+ size_t size = outputTensor.second.GetNumBytes();
+ double result = ComputeByteLevelRMSE(outputTensor.second.GetMemoryArea(), otherOutput[index++], size);
+ std::cout << "Byte level root mean square error: " << result << "\n";
+ }
+ }
+}
+
+FileComparisonExecutor::~FileComparisonExecutor()
+{
+ // If there are tensors defined in m_OutputTensorsVec we need to clean up their memory usage.
+ for (OutputTensors opTensor : m_OutputTensorsVec)
+ {
+ for (std::pair<LayerBindingId, class Tensor> pair : opTensor)
+ {
+ Tensor t = pair.second;
+ // Based on the tensor type and size recover the memory.
+ switch (t.GetDataType())
+ {
+ case DataType::Float32:
+ delete[] static_cast<float*>(t.GetMemoryArea());
+ break;
+ case DataType::Signed32:
+ delete[] static_cast<int*>(t.GetMemoryArea());
+ break;
+ case DataType::QSymmS8:
+ delete[] static_cast<int8_t*>(t.GetMemoryArea());
+ break;
+ case DataType::QAsymmS8:
+ delete[] static_cast<int8_t*>(t.GetMemoryArea());
+ break;
+ case DataType::QAsymmU8:
+ delete[] static_cast<uint8_t*>(t.GetMemoryArea());
+ break;
+ default:
+ std::cout << "The data type wasn't created in ReadTensorFromFile" << std::endl;
+ }
+ }
+ }
+
+}
diff --git a/tests/ExecuteNetwork/FileComparisonExecutor.hpp b/tests/ExecuteNetwork/FileComparisonExecutor.hpp
new file mode 100644
index 0000000000..04e0339504
--- /dev/null
+++ b/tests/ExecuteNetwork/FileComparisonExecutor.hpp
@@ -0,0 +1,27 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "ExecuteNetworkProgramOptions.hpp"
+#include "IExecutor.hpp"
+
+class FileComparisonExecutor : public IExecutor
+{
+public:
+ FileComparisonExecutor(const ExecuteNetworkParams& params);
+ ~FileComparisonExecutor();
+ std::vector<const void*> Execute() override;
+ void PrintNetworkInfo() override;
+ void CompareAndPrintResult(std::vector<const void*> otherOutput) override;
+
+private:
+ // Disallow copy and assignment constructors.
+ FileComparisonExecutor(FileComparisonExecutor&);
+ FileComparisonExecutor operator=(const FileComparisonExecutor&);
+
+ ExecuteNetworkParams m_Params;
+ std::vector<armnn::OutputTensors> m_OutputTensorsVec;
+};
diff --git a/tests/ExecuteNetwork/TfliteExecutor.cpp b/tests/ExecuteNetwork/TfliteExecutor.cpp
index 6455650404..04e510f938 100644
--- a/tests/ExecuteNetwork/TfliteExecutor.cpp
+++ b/tests/ExecuteNetwork/TfliteExecutor.cpp
@@ -11,6 +11,50 @@
#include "TfliteExecutor.hpp"
#include "tensorflow/lite/kernels/kernel_util.h"
+#include <string>
+
+std::string TfLiteStatusToString(const TfLiteStatus status)
+{
+ switch (status)
+ {
+ case kTfLiteOk:
+ return "Status: Ok.";
+ // Generally referring to an error in the runtime (i.e. interpreter)
+ case kTfLiteError:
+ return "Status: Tf runtime error.";
+ // Generally referring to an error from a TfLiteDelegate itself.
+ case kTfLiteDelegateError:
+ return "Status: The loaded delegate has returned an error.";
+ // Generally referring to an error in applying a delegate due to
+ // incompatibility between runtime and delegate, e.g., this error is returned
+ // when trying to apply a TF Lite delegate onto a model graph that's already
+ // immutable.
+ case kTfLiteApplicationError:
+ return "Status: Application error. An incompatibility between the Tf runtime and the loaded delegate.";
+ // Generally referring to serialized delegate data not being found.
+ // See tflite::delegates::Serialization.
+ case kTfLiteDelegateDataNotFound:
+ return "Status: data not found.";
+ // Generally referring to data-writing issues in delegate serialization.
+ // See tflite::delegates::Serialization.
+ case kTfLiteDelegateDataWriteError:
+ return "Status: Error writing serialization data.";
+ // Generally referring to data-reading issues in delegate serialization.
+ // See tflite::delegates::Serialization.
+ case kTfLiteDelegateDataReadError:
+ return "Status: Error reading serialization data.";
+ // Generally referring to issues when the TF Lite model has ops that cannot be
+ // resolved at runtime. This could happen when the specific op is not
+ // registered or built with the TF Lite framework.
+ case kTfLiteUnresolvedOps:
+ return "Status: Model contains an operation that is not recognised by the runtime.";
+ // Generally referring to invocation cancelled by the user.
+ case kTfLiteCancelled:
+ return "Status: invocation has been cancelled by the user.";
+ }
+ return "Unknown status result.";
+}
+
TfLiteExecutor::TfLiteExecutor(const ExecuteNetworkParams& params, armnn::IRuntime::CreationOptions runtimeOptions)
: m_Params(params)
{
@@ -67,9 +111,11 @@ TfLiteExecutor::TfLiteExecutor(const ExecuteNetworkParams& params, armnn::IRunti
theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
armnnDelegate::TfLiteArmnnDelegateDelete);
// Register armnn_delegate to TfLiteInterpreter
- if (m_TfLiteInterpreter->ModifyGraphWithDelegate(std::move(theArmnnDelegate)) != kTfLiteOk)
+ auto result = m_TfLiteInterpreter->ModifyGraphWithDelegate(std::move(theArmnnDelegate));
+ if (result != kTfLiteOk)
{
- LogAndThrow("Could not register ArmNN TfLite Delegate to TfLiteInterpreter.");
+ LogAndThrow("Could not register ArmNN TfLite Delegate to TfLiteInterpreter: " +
+ TfLiteStatusToString(result) + ".");
}
#else
LogAndThrow("Not built with Arm NN Tensorflow-Lite delegate support.");
diff --git a/tests/ExecuteNetwork/test/FileComparisonExecutorTests.cpp b/tests/ExecuteNetwork/test/FileComparisonExecutorTests.cpp
new file mode 100644
index 0000000000..c8a7171107
--- /dev/null
+++ b/tests/ExecuteNetwork/test/FileComparisonExecutorTests.cpp
@@ -0,0 +1,74 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <ExecuteNetwork/FileComparisonExecutor.hpp>
+#include <doctest/doctest.h>
+#include <filesystem>
+#include <fstream>
+namespace
+{
+
+namespace fs = std::filesystem;
+
+TEST_SUITE("FileComparisonExecutorTests")
+{
+
+ TEST_CASE("EmptyComparisonThrowsException")
+ {
+ ExecuteNetworkParams params;
+ FileComparisonExecutor classToTest(params);
+ // The comparison file is not set in the parameters. This should throw an exception.
+ CHECK_THROWS_AS(classToTest.Execute(), armnn::InvalidArgumentException);
+ }
+
+ TEST_CASE("InvalidComparisonFilesThrowsException")
+ {
+ ExecuteNetworkParams params;
+ params.m_ComparisonFile = "Balh,Blah,Blah";
+ FileComparisonExecutor classToTest(params);
+ // None of the files in the parameter exist.
+ CHECK_THROWS_AS(classToTest.Execute(), armnn::FileNotFoundException);
+ }
+
+ TEST_CASE("ComparisonFileIsEmpty")
+ {
+ std::filesystem::path fileName = fs::temp_directory_path().append("ComparisonFileIsEmpty.tmp");
+ std::fstream tmpFile;
+ tmpFile.open(fileName, std::ios::out);
+ ExecuteNetworkParams params;
+ params.m_ComparisonFile = fileName;
+ FileComparisonExecutor classToTest(params);
+ // The comparison file is empty. This exception should happen in ExtractHeader when it realises it
+ // can't read a header.
+ CHECK_THROWS_AS(classToTest.Execute(), armnn::ParseException);
+ tmpFile.close();
+ std::filesystem::remove(fileName);
+ }
+
+ TEST_CASE("ComparisonFileHasValidHeaderAndData")
+ {
+ std::filesystem::path fileName = fs::temp_directory_path().append("ComparisonFileHasValidHeaderAndData.tmp");
+ std::fstream tmpFile;
+ tmpFile.open(fileName, std::ios::out);
+ // Write a valid header.
+ tmpFile << "TensorName, Float32 : 1.1000";
+ tmpFile.close();
+ ExecuteNetworkParams params;
+ params.m_ComparisonFile = fileName;
+ FileComparisonExecutor classToTest(params);
+ // The read in tensor should consist of 1 float.
+ std::vector<const void*> results = classToTest.Execute();
+ std::filesystem::remove(fileName);
+ // Should be one tensor in the data.
+ CHECK_EQ(1, results.size());
+ // We expect there to be 1 element of value 1.1f.
+ const float* floatPtr = static_cast<const float*>(results[0]);
+ CHECK_EQ(*floatPtr, 1.1f);
+ }
+
+
+} // End of TEST_SUITE("FileComparisonExecutorTests")
+
+} // anonymous namespace \ No newline at end of file
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
index 2136c446fb..0df3bf5ef5 100644
--- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
+++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -184,13 +184,14 @@ template<typename T>
void WriteToFile(const std::string& outputTensorFileName,
const std::string& outputName,
const T* const array,
- const unsigned int numElements)
+ const unsigned int numElements,
+ armnn::DataType dataType)
{
std::ofstream outputTensorFile;
outputTensorFile.open(outputTensorFileName, std::ofstream::out | std::ofstream::trunc);
if (outputTensorFile.is_open())
{
- outputTensorFile << outputName << ": ";
+ outputTensorFile << outputName << ", "<< GetDataTypeName(dataType) << " : ";
for (std::size_t i = 0; i < numElements; ++i)
{
outputTensorFile << +array[i] << " ";
@@ -209,6 +210,7 @@ struct OutputWriteInfo
const std::string& m_OutputName;
const armnn::Tensor& m_Tensor;
const bool m_PrintTensor;
+ const armnn::DataType m_DataType;
};
template <typename T>
@@ -221,7 +223,8 @@ void PrintTensor(OutputWriteInfo& info, const char* formatString)
WriteToFile(info.m_OutputTensorFile.value(),
info.m_OutputName,
array,
- info.m_Tensor.GetNumElements());
+ info.m_Tensor.GetNumElements(),
+ info.m_DataType);
}
if (info.m_PrintTensor)
@@ -248,7 +251,8 @@ void PrintQuantizedTensor(OutputWriteInfo& info)
WriteToFile(info.m_OutputTensorFile.value(),
info.m_OutputName,
dequantizedValues.data(),
- tensor.GetNumElements());
+ tensor.GetNumElements(),
+ info.m_DataType);
}
if (info.m_PrintTensor)