diff options
Diffstat (limited to 'src/backends/cl/ClWorkloadFactory.cpp')
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 58 |
1 files changed, 50 insertions, 8 deletions
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 2f94ef0970..134dad576e 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -32,6 +32,8 @@ #include <armnnUtils/Filesystem.hpp> #include <fstream> +#include <sys/stat.h> + namespace armnn { @@ -67,13 +69,29 @@ void ClWorkloadFactory::AfterWorkloadsCreated() auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get()); if (modelOptions->SaveCachedNetwork()) { + ClContextSerializer serializer; + serializer.Serialize(m_CLCompileContext); + auto cachedFd = modelOptions->GetCachedFileDescriptor(); + if (cachedFd != -1) + { + std::vector<uint8_t> compiledContextData; + std::stringstream stream; + bool serialized = serializer.SaveSerializedToStream(stream); + if (serialized) + { + std::string const serializedString{stream.str()}; + std::copy(serializedString.begin(), + serializedString.end(), + std::back_inserter(compiledContextData)); + write(cachedFd, compiledContextData.data(), compiledContextData.size()); + } + } + // Save map to a filepath provided in ModelOptions auto filePath = modelOptions->GetCachedNetworkFilePath(); if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath)) { // Serialize ClContext to the file specified - ClContextSerializer serializer; - serializer.Serialize(m_CLCompileContext); std::ofstream file(filePath, std::ios::out | std::ios::binary); serializer.SaveSerializedToStream(file); } @@ -123,14 +141,38 @@ void ClWorkloadFactory::InitializeCLCompileContext() // Load saved programs if the user has set a filepath auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get()); auto filePath = modelOptions->GetCachedNetworkFilePath(); - if (filePath != "" - && fs::exists(filePath) - && fs::is_regular_file(filePath) - && !(modelOptions->SaveCachedNetwork())) + if (!(modelOptions->SaveCachedNetwork())) { - // Deserialize binary file and load into m_CLCompileContext ClContextDeserializer deserializer; - deserializer.Deserialize(m_CLCompileContext, context, device, filePath); + auto cachedFd = modelOptions->GetCachedFileDescriptor(); + if (cachedFd != -1) + { + struct stat statBuffer; + if (fstat(cachedFd, &statBuffer) == 0) + { + long dataSize = static_cast<long>(statBuffer.st_size); + if( dataSize > 0) + { + auto offset = lseek(cachedFd, 0, SEEK_CUR); + if (offset == 0) + { + std::vector <uint8_t> compiledContextData(static_cast<unsigned int>(dataSize)); + pread(cachedFd, compiledContextData.data(), compiledContextData.size(), 0); + deserializer.DeserializeFromBinary(m_CLCompileContext, + context, + device, + compiledContextData); + } + } + + } + } + + if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath)) + { + // Deserialize binary file and load into m_CLCompileContext + deserializer.Deserialize(m_CLCompileContext, context, device, filePath); + } } } } |