diff options
Diffstat (limited to 'tests/ExecuteNetwork/ArmNNExecutor.cpp')
-rw-r--r-- | tests/ExecuteNetwork/ArmNNExecutor.cpp | 152 |
1 files changed, 145 insertions, 7 deletions
diff --git a/tests/ExecuteNetwork/ArmNNExecutor.cpp b/tests/ExecuteNetwork/ArmNNExecutor.cpp index 4518f1426f..2004bb1ec0 100644 --- a/tests/ExecuteNetwork/ArmNNExecutor.cpp +++ b/tests/ExecuteNetwork/ArmNNExecutor.cpp @@ -7,27 +7,157 @@ #include "ArmNNExecutor.hpp" #include "NetworkExecutionUtils/NetworkExecutionUtils.hpp" -#include <armnn/IAsyncExecutionCallback.hpp> #include <AsyncExecutionCallback.hpp> - - +#include <armnn/IAsyncExecutionCallback.hpp> +#if defined(ARMNN_SERIALIZER) +#include <armnnSerializer/ISerializer.hpp> +#endif using namespace armnn; using namespace std::chrono; +#if defined(ARMNN_SERIALIZER) +/** + * Given a reference to an INetwork and a target directory, serialize the network to a file + * called "<timestamp>_network.armnn" + * + * @param network The network to serialize. + * @param dumpDir The target directory. + * @return the full path to the serialized file. + */ +std::string SerializeNetwork(const armnn::INetwork& network, const std::string& dumpDir) +{ + if (dumpDir.empty()) + { + throw InvalidArgumentException("An output directory must be specified."); + } + fs::path outputDirectory(dumpDir); + if (!exists(outputDirectory)) + { + throw InvalidArgumentException( + fmt::format("The specified directory does not exist: {}", outputDirectory.c_str())); + } + auto serializer(armnnSerializer::ISerializer::Create()); + // Serialize the Network + serializer->Serialize(network); + + fs::path fileName; + fileName += dumpDir; + // used to get a timestamp to name diagnostic files (the ArmNN serialized graph + // and getSupportedOperations.txt files) + timespec ts; + if (clock_gettime(CLOCK_MONOTONIC_RAW, &ts) == 0) + { + std::stringstream ss; + ss << std::to_string(ts.tv_sec) << "_" << std::to_string(ts.tv_nsec) << "_network.armnn"; + fileName += ss.str(); + } + else + { + // This is incredibly unlikely but just in case. + throw RuntimeException("clock_gettime, CLOCK_MONOTONIC_RAW returned a non zero result."); + } + + // Save serialized network to a file + std::ofstream serializedFile(fileName, std::ios::out | std::ios::binary); + auto serialized = serializer->SaveSerializedToStream(serializedFile); + if (!serialized) + { + throw RuntimeException(fmt::format("An error occurred when serializing to file %s", fileName.c_str())); + } + serializedFile.flush(); + serializedFile.close(); + return fileName; +} + +/** + * Given a reference to an optimized network and a target directory, serialize the network in .dot file format to + * a file called "<timestamp>_optimized_networkgraph.dot" + * + * @param network The network to serialize. + * @param dumpDir The target directory. + * @return the full path to the serialized file. + */ +std::string SerializeNetworkToDotFile(const armnn::IOptimizedNetwork& optimizedNetwork, const std::string& dumpDir) +{ + if (dumpDir.empty()) + { + throw InvalidArgumentException("An output directory must be specified."); + } + fs::path outputDirectory(dumpDir); + if (!exists(outputDirectory)) + { + throw InvalidArgumentException( + fmt::format("The specified directory does not exist: {}", outputDirectory.c_str())); + } + + fs::path fileName; + fileName += dumpDir; + // used to get a timestamp to name diagnostic files (the ArmNN serialized graph + // and getSupportedOperations.txt files) + timespec ts; + if (clock_gettime(CLOCK_MONOTONIC_RAW, &ts) == 0) + { + std::stringstream ss; + ss << std::to_string(ts.tv_sec) << "_" << std::to_string(ts.tv_nsec) << "_optimized_networkgraph.dot"; + fileName += ss.str(); + } + else + { + // This is incredibly unlikely but just in case. + throw RuntimeException("clock_gettime, CLOCK_MONOTONIC_RAW returned a non zero result."); + } + + // Write the network graph to a dot file. + std::ofstream fileStream; + fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc); + if (!fileStream.good()) + { + throw RuntimeException(fmt::format("An error occurred when creating %s", fileName.c_str())); + } + + if (optimizedNetwork.SerializeToDot(fileStream) != armnn::Status::Success) + { + throw RuntimeException(fmt::format("An error occurred when serializing to file %s", fileName.c_str())); + } + fileStream.flush(); + fileStream.close(); + return fileName; +} +#endif + ArmNNExecutor::ArmNNExecutor(const ExecuteNetworkParams& params, armnn::IRuntime::CreationOptions runtimeOptions) -: m_Params(params) + : m_Params(params) { - runtimeOptions.m_EnableGpuProfiling = params.m_EnableProfiling; + runtimeOptions.m_EnableGpuProfiling = params.m_EnableProfiling; runtimeOptions.m_DynamicBackendsPath = params.m_DynamicBackendsPath; // Create/Get the static ArmNN Runtime. Note that the m_Runtime will be shared by all ArmNNExecutor // instances so the RuntimeOptions cannot be altered for different ArmNNExecutor instances. m_Runtime = GetRuntime(runtimeOptions); - auto parser = CreateParser(); + auto parser = CreateParser(); auto network = parser->CreateNetwork(m_Params); - auto optNet = OptimizeNetwork(network.get()); + auto optNet = OptimizeNetwork(network.get()); + // If the user has asked for detailed data write out the .armnn amd .dot files. + if (params.m_SerializeToArmNN) + { +#if defined(ARMNN_SERIALIZER) + // .armnn first. + // This could throw multiple exceptions if the directory cannot be created or the file cannot be written. + std::string targetDirectory(armnnUtils::Filesystem::CreateDirectory("/ArmNNSerializeNetwork")); + std::string fileName; + fileName = SerializeNetwork(*network, targetDirectory); + ARMNN_LOG(info) << "The pre-optimized network has been serialized to:" << fileName; + // and the .dot file. + // Most of the possible exceptions should have already occurred with the .armnn file. + fileName = + SerializeNetworkToDotFile(*optNet, targetDirectory); + ARMNN_LOG(info) << "The optimized network has been serialized to:" << fileName; +#else + ARMNN_LOG(info) << "Arm NN has not been built with ARMNN_SERIALIZER enabled."; +#endif + } m_IOInfo = GetIOInfo(optNet.get()); armnn::ProfilingDetailsMethod profilingDetailsMethod = ProfilingDetailsMethod::Undefined; @@ -176,6 +306,12 @@ void ArmNNExecutor::ExecuteAsync() void ArmNNExecutor::ExecuteSync() { + // If we've only been asked to serialize the networks, don't execute the inference. + if (m_Params.m_SerializeToArmNN) + { + ARMNN_LOG(info) << "serialize-to-armnn has been specified. No inference will be executed."; + return; + } for (size_t x = 0; x < m_Params.m_Iterations; x++) { std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkId); @@ -800,6 +936,7 @@ armnn::BindingPointInfo ArmNNExecutor::TfliteParser::GetOutputBindingPointInfo(s #if defined(ARMNN_ONNX_PARSER) +ARMNN_NO_DEPRECATE_WARN_BEGIN ArmNNExecutor::OnnxParser::OnnxParser() : m_Parser(armnnOnnxParser::IOnnxParser::Create()){} armnn::INetworkPtr ArmNNExecutor::OnnxParser::CreateNetwork(const ExecuteNetworkParams& params) @@ -843,4 +980,5 @@ armnn::BindingPointInfo ArmNNExecutor::OnnxParser::GetOutputBindingPointInfo(siz { return m_Parser->GetNetworkOutputBindingInfo(outputName); } +ARMNN_NO_DEPRECATE_WARN_END #endif |