aboutsummaryrefslogtreecommitdiff
path: root/src/backends
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends')
-rw-r--r--src/backends/cl/ClBackend.hpp2
-rw-r--r--src/backends/cl/ClBackendModelContext.cpp11
-rw-r--r--src/backends/cl/ClBackendModelContext.hpp3
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp58
4 files changed, 65 insertions, 9 deletions
diff --git a/src/backends/cl/ClBackend.hpp b/src/backends/cl/ClBackend.hpp
index e0708d18e2..ffce800261 100644
--- a/src/backends/cl/ClBackend.hpp
+++ b/src/backends/cl/ClBackend.hpp
@@ -105,6 +105,8 @@ public:
return m_UsingCustomAllocator;
}
+ virtual unsigned int GetNumberOfCacheFiles() const override { return 1; }
+
// Cl requires a arm_compute::IAllocator we wrap the Arm NN ICustomAllocator to achieve this
class ClBackendCustomAllocatorWrapper : public arm_compute::IAllocator
{
diff --git a/src/backends/cl/ClBackendModelContext.cpp b/src/backends/cl/ClBackendModelContext.cpp
index b685bc296c..75a2e05bda 100644
--- a/src/backends/cl/ClBackendModelContext.cpp
+++ b/src/backends/cl/ClBackendModelContext.cpp
@@ -32,7 +32,7 @@ namespace armnn
{
ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions)
- : m_CachedNetworkFilePath(""), m_IsFastMathEnabled(false), m_SaveCachedNetwork(false)
+ : m_CachedNetworkFilePath(""), m_IsFastMathEnabled(false), m_SaveCachedNetwork(false), m_CachedFileDescriptor(-1)
{
if (!modelOptions.empty())
{
@@ -50,6 +50,10 @@ ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions)
{
m_CachedNetworkFilePath = ParseFile(value, "");
}
+ if (name == "CachedFileDescriptor")
+ {
+ m_CachedFileDescriptor = armnn::ParseIntBackendOption(value, -1);
+ }
});
}
}
@@ -69,4 +73,9 @@ bool ClBackendModelContext::SaveCachedNetwork() const
return m_SaveCachedNetwork;
}
+int ClBackendModelContext::GetCachedFileDescriptor() const
+{
+ return m_CachedFileDescriptor;
+}
+
} // namespace armnn \ No newline at end of file
diff --git a/src/backends/cl/ClBackendModelContext.hpp b/src/backends/cl/ClBackendModelContext.hpp
index e7a26cd688..4feb6aa452 100644
--- a/src/backends/cl/ClBackendModelContext.hpp
+++ b/src/backends/cl/ClBackendModelContext.hpp
@@ -36,10 +36,13 @@ public:
bool SaveCachedNetwork() const;
+ int GetCachedFileDescriptor() const;
+
private:
std::string m_CachedNetworkFilePath;
bool m_IsFastMathEnabled;
bool m_SaveCachedNetwork;
+ int m_CachedFileDescriptor;
};
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index 2f94ef0970..134dad576e 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -32,6 +32,8 @@
#include <armnnUtils/Filesystem.hpp>
#include <fstream>
+#include <sys/stat.h>
+
namespace armnn
{
@@ -67,13 +69,29 @@ void ClWorkloadFactory::AfterWorkloadsCreated()
auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
if (modelOptions->SaveCachedNetwork())
{
+ ClContextSerializer serializer;
+ serializer.Serialize(m_CLCompileContext);
+ auto cachedFd = modelOptions->GetCachedFileDescriptor();
+ if (cachedFd != -1)
+ {
+ std::vector<uint8_t> compiledContextData;
+ std::stringstream stream;
+ bool serialized = serializer.SaveSerializedToStream(stream);
+ if (serialized)
+ {
+ std::string const serializedString{stream.str()};
+ std::copy(serializedString.begin(),
+ serializedString.end(),
+ std::back_inserter(compiledContextData));
+ write(cachedFd, compiledContextData.data(), compiledContextData.size());
+ }
+ }
+
// Save map to a filepath provided in ModelOptions
auto filePath = modelOptions->GetCachedNetworkFilePath();
if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath))
{
// Serialize ClContext to the file specified
- ClContextSerializer serializer;
- serializer.Serialize(m_CLCompileContext);
std::ofstream file(filePath, std::ios::out | std::ios::binary);
serializer.SaveSerializedToStream(file);
}
@@ -123,14 +141,38 @@ void ClWorkloadFactory::InitializeCLCompileContext()
// Load saved programs if the user has set a filepath
auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
auto filePath = modelOptions->GetCachedNetworkFilePath();
- if (filePath != ""
- && fs::exists(filePath)
- && fs::is_regular_file(filePath)
- && !(modelOptions->SaveCachedNetwork()))
+ if (!(modelOptions->SaveCachedNetwork()))
{
- // Deserialize binary file and load into m_CLCompileContext
ClContextDeserializer deserializer;
- deserializer.Deserialize(m_CLCompileContext, context, device, filePath);
+ auto cachedFd = modelOptions->GetCachedFileDescriptor();
+ if (cachedFd != -1)
+ {
+ struct stat statBuffer;
+ if (fstat(cachedFd, &statBuffer) == 0)
+ {
+ long dataSize = static_cast<long>(statBuffer.st_size);
+ if( dataSize > 0)
+ {
+ auto offset = lseek(cachedFd, 0, SEEK_CUR);
+ if (offset == 0)
+ {
+ std::vector <uint8_t> compiledContextData(static_cast<unsigned int>(dataSize));
+ pread(cachedFd, compiledContextData.data(), compiledContextData.size(), 0);
+ deserializer.DeserializeFromBinary(m_CLCompileContext,
+ context,
+ device,
+ compiledContextData);
+ }
+ }
+
+ }
+ }
+
+ if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath))
+ {
+ // Deserialize binary file and load into m_CLCompileContext
+ deserializer.Deserialize(m_CLCompileContext, context, device, filePath);
+ }
}
}
}