aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/ClWorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/ClWorkloadFactory.cpp')
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp58
1 files changed, 50 insertions, 8 deletions
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index 2f94ef0970..134dad576e 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -32,6 +32,8 @@
#include <armnnUtils/Filesystem.hpp>
#include <fstream>
+#include <sys/stat.h>
+
namespace armnn
{
@@ -67,13 +69,29 @@ void ClWorkloadFactory::AfterWorkloadsCreated()
auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
if (modelOptions->SaveCachedNetwork())
{
+ ClContextSerializer serializer;
+ serializer.Serialize(m_CLCompileContext);
+ auto cachedFd = modelOptions->GetCachedFileDescriptor();
+ if (cachedFd != -1)
+ {
+ std::vector<uint8_t> compiledContextData;
+ std::stringstream stream;
+ bool serialized = serializer.SaveSerializedToStream(stream);
+ if (serialized)
+ {
+ std::string const serializedString{stream.str()};
+ std::copy(serializedString.begin(),
+ serializedString.end(),
+ std::back_inserter(compiledContextData));
+ write(cachedFd, compiledContextData.data(), compiledContextData.size());
+ }
+ }
+
// Save map to a filepath provided in ModelOptions
auto filePath = modelOptions->GetCachedNetworkFilePath();
if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath))
{
// Serialize ClContext to the file specified
- ClContextSerializer serializer;
- serializer.Serialize(m_CLCompileContext);
std::ofstream file(filePath, std::ios::out | std::ios::binary);
serializer.SaveSerializedToStream(file);
}
@@ -123,14 +141,38 @@ void ClWorkloadFactory::InitializeCLCompileContext()
// Load saved programs if the user has set a filepath
auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
auto filePath = modelOptions->GetCachedNetworkFilePath();
- if (filePath != ""
- && fs::exists(filePath)
- && fs::is_regular_file(filePath)
- && !(modelOptions->SaveCachedNetwork()))
+ if (!(modelOptions->SaveCachedNetwork()))
{
- // Deserialize binary file and load into m_CLCompileContext
ClContextDeserializer deserializer;
- deserializer.Deserialize(m_CLCompileContext, context, device, filePath);
+ auto cachedFd = modelOptions->GetCachedFileDescriptor();
+ if (cachedFd != -1)
+ {
+ struct stat statBuffer;
+ if (fstat(cachedFd, &statBuffer) == 0)
+ {
+ long dataSize = static_cast<long>(statBuffer.st_size);
+ if( dataSize > 0)
+ {
+ auto offset = lseek(cachedFd, 0, SEEK_CUR);
+ if (offset == 0)
+ {
+ std::vector <uint8_t> compiledContextData(static_cast<unsigned int>(dataSize));
+ pread(cachedFd, compiledContextData.data(), compiledContextData.size(), 0);
+ deserializer.DeserializeFromBinary(m_CLCompileContext,
+ context,
+ device,
+ compiledContextData);
+ }
+ }
+
+ }
+ }
+
+ if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath))
+ {
+ // Deserialize binary file and load into m_CLCompileContext
+ deserializer.Deserialize(m_CLCompileContext, context, device, filePath);
+ }
}
}
}