diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2021-10-06 16:37:02 +0100 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2021-10-28 08:17:26 +0000 |
commit | b7851f9b29dbbe995acf6dc271471e73261c196a (patch) | |
tree | b00b29ea8e4161b11473053a0957f4a902aca837 /src/backends/cl/ClWorkloadFactory.cpp | |
parent | af3a4ef77d8f330a995911b979417857514df62c (diff) | |
download | armnn-b7851f9b29dbbe995acf6dc271471e73261c196a.tar.gz |
IVGCVSW-5636 'Implement NNAPI caching functions'
* Get number of inputs and outputs from optimized network.
* Get number of cached files if backend supports caching.
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: Ie02ac123bb7df9b0593a2fe46b5bb564a5994780
Diffstat (limited to 'src/backends/cl/ClWorkloadFactory.cpp')
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 58 |
1 files changed, 50 insertions, 8 deletions
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 2f94ef0970..134dad576e 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -32,6 +32,8 @@ #include <armnnUtils/Filesystem.hpp> #include <fstream> +#include <sys/stat.h> + namespace armnn { @@ -67,13 +69,29 @@ void ClWorkloadFactory::AfterWorkloadsCreated() auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get()); if (modelOptions->SaveCachedNetwork()) { + ClContextSerializer serializer; + serializer.Serialize(m_CLCompileContext); + auto cachedFd = modelOptions->GetCachedFileDescriptor(); + if (cachedFd != -1) + { + std::vector<uint8_t> compiledContextData; + std::stringstream stream; + bool serialized = serializer.SaveSerializedToStream(stream); + if (serialized) + { + std::string const serializedString{stream.str()}; + std::copy(serializedString.begin(), + serializedString.end(), + std::back_inserter(compiledContextData)); + write(cachedFd, compiledContextData.data(), compiledContextData.size()); + } + } + // Save map to a filepath provided in ModelOptions auto filePath = modelOptions->GetCachedNetworkFilePath(); if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath)) { // Serialize ClContext to the file specified - ClContextSerializer serializer; - serializer.Serialize(m_CLCompileContext); std::ofstream file(filePath, std::ios::out | std::ios::binary); serializer.SaveSerializedToStream(file); } @@ -123,14 +141,38 @@ void ClWorkloadFactory::InitializeCLCompileContext() // Load saved programs if the user has set a filepath auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get()); auto filePath = modelOptions->GetCachedNetworkFilePath(); - if (filePath != "" - && fs::exists(filePath) - && fs::is_regular_file(filePath) - && !(modelOptions->SaveCachedNetwork())) + if (!(modelOptions->SaveCachedNetwork())) { - // Deserialize binary file and load into m_CLCompileContext ClContextDeserializer deserializer; - deserializer.Deserialize(m_CLCompileContext, context, device, filePath); + auto cachedFd = modelOptions->GetCachedFileDescriptor(); + if (cachedFd != -1) + { + struct stat statBuffer; + if (fstat(cachedFd, &statBuffer) == 0) + { + long dataSize = static_cast<long>(statBuffer.st_size); + if( dataSize > 0) + { + auto offset = lseek(cachedFd, 0, SEEK_CUR); + if (offset == 0) + { + std::vector <uint8_t> compiledContextData(static_cast<unsigned int>(dataSize)); + pread(cachedFd, compiledContextData.data(), compiledContextData.size(), 0); + deserializer.DeserializeFromBinary(m_CLCompileContext, + context, + device, + compiledContextData); + } + } + + } + } + + if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath)) + { + // Deserialize binary file and load into m_CLCompileContext + deserializer.Deserialize(m_CLCompileContext, context, device, filePath); + } } } } |