aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/ClWorkloadFactory.cpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2021-10-06 16:37:02 +0100
committerSadik Armagan <sadik.armagan@arm.com>2021-10-28 08:17:26 +0000
commitb7851f9b29dbbe995acf6dc271471e73261c196a (patch)
treeb00b29ea8e4161b11473053a0957f4a902aca837 /src/backends/cl/ClWorkloadFactory.cpp
parentaf3a4ef77d8f330a995911b979417857514df62c (diff)
downloadarmnn-b7851f9b29dbbe995acf6dc271471e73261c196a.tar.gz
IVGCVSW-5636 'Implement NNAPI caching functions'
* Get number of inputs and outputs from optimized network. * Get number of cached files if backend supports caching. Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: Ie02ac123bb7df9b0593a2fe46b5bb564a5994780
Diffstat (limited to 'src/backends/cl/ClWorkloadFactory.cpp')
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp58
1 files changed, 50 insertions, 8 deletions
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index 2f94ef0970..134dad576e 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -32,6 +32,8 @@
#include <armnnUtils/Filesystem.hpp>
#include <fstream>
+#include <sys/stat.h>
+
namespace armnn
{
@@ -67,13 +69,29 @@ void ClWorkloadFactory::AfterWorkloadsCreated()
auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
if (modelOptions->SaveCachedNetwork())
{
+ ClContextSerializer serializer;
+ serializer.Serialize(m_CLCompileContext);
+ auto cachedFd = modelOptions->GetCachedFileDescriptor();
+ if (cachedFd != -1)
+ {
+ std::vector<uint8_t> compiledContextData;
+ std::stringstream stream;
+ bool serialized = serializer.SaveSerializedToStream(stream);
+ if (serialized)
+ {
+ std::string const serializedString{stream.str()};
+ std::copy(serializedString.begin(),
+ serializedString.end(),
+ std::back_inserter(compiledContextData));
+ write(cachedFd, compiledContextData.data(), compiledContextData.size());
+ }
+ }
+
// Save map to a filepath provided in ModelOptions
auto filePath = modelOptions->GetCachedNetworkFilePath();
if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath))
{
// Serialize ClContext to the file specified
- ClContextSerializer serializer;
- serializer.Serialize(m_CLCompileContext);
std::ofstream file(filePath, std::ios::out | std::ios::binary);
serializer.SaveSerializedToStream(file);
}
@@ -123,14 +141,38 @@ void ClWorkloadFactory::InitializeCLCompileContext()
// Load saved programs if the user has set a filepath
auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
auto filePath = modelOptions->GetCachedNetworkFilePath();
- if (filePath != ""
- && fs::exists(filePath)
- && fs::is_regular_file(filePath)
- && !(modelOptions->SaveCachedNetwork()))
+ if (!(modelOptions->SaveCachedNetwork()))
{
- // Deserialize binary file and load into m_CLCompileContext
ClContextDeserializer deserializer;
- deserializer.Deserialize(m_CLCompileContext, context, device, filePath);
+ auto cachedFd = modelOptions->GetCachedFileDescriptor();
+ if (cachedFd != -1)
+ {
+ struct stat statBuffer;
+ if (fstat(cachedFd, &statBuffer) == 0)
+ {
+ long dataSize = static_cast<long>(statBuffer.st_size);
+ if( dataSize > 0)
+ {
+ auto offset = lseek(cachedFd, 0, SEEK_CUR);
+ if (offset == 0)
+ {
+ std::vector <uint8_t> compiledContextData(static_cast<unsigned int>(dataSize));
+ pread(cachedFd, compiledContextData.data(), compiledContextData.size(), 0);
+ deserializer.DeserializeFromBinary(m_CLCompileContext,
+ context,
+ device,
+ compiledContextData);
+ }
+ }
+
+ }
+ }
+
+ if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath))
+ {
+ // Deserialize binary file and load into m_CLCompileContext
+ deserializer.Deserialize(m_CLCompileContext, context, device, filePath);
+ }
}
}
}