From b7851f9b29dbbe995acf6dc271471e73261c196a Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Wed, 6 Oct 2021 16:37:02 +0100 Subject: IVGCVSW-5636 'Implement NNAPI caching functions' * Get number of inputs and outputs from optimized network. * Get number of cached files if backend supports caching. Signed-off-by: Sadik Armagan Change-Id: Ie02ac123bb7df9b0593a2fe46b5bb564a5994780 --- src/backends/cl/ClBackend.hpp | 2 ++ src/backends/cl/ClBackendModelContext.cpp | 11 +++++- src/backends/cl/ClBackendModelContext.hpp | 3 ++ src/backends/cl/ClWorkloadFactory.cpp | 58 ++++++++++++++++++++++++++----- 4 files changed, 65 insertions(+), 9 deletions(-) (limited to 'src/backends/cl') diff --git a/src/backends/cl/ClBackend.hpp b/src/backends/cl/ClBackend.hpp index e0708d18e2..ffce800261 100644 --- a/src/backends/cl/ClBackend.hpp +++ b/src/backends/cl/ClBackend.hpp @@ -105,6 +105,8 @@ public: return m_UsingCustomAllocator; } + virtual unsigned int GetNumberOfCacheFiles() const override { return 1; } + // Cl requires a arm_compute::IAllocator we wrap the Arm NN ICustomAllocator to achieve this class ClBackendCustomAllocatorWrapper : public arm_compute::IAllocator { diff --git a/src/backends/cl/ClBackendModelContext.cpp b/src/backends/cl/ClBackendModelContext.cpp index b685bc296c..75a2e05bda 100644 --- a/src/backends/cl/ClBackendModelContext.cpp +++ b/src/backends/cl/ClBackendModelContext.cpp @@ -32,7 +32,7 @@ namespace armnn { ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions) - : m_CachedNetworkFilePath(""), m_IsFastMathEnabled(false), m_SaveCachedNetwork(false) + : m_CachedNetworkFilePath(""), m_IsFastMathEnabled(false), m_SaveCachedNetwork(false), m_CachedFileDescriptor(-1) { if (!modelOptions.empty()) { @@ -50,6 +50,10 @@ ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions) { m_CachedNetworkFilePath = ParseFile(value, ""); } + if (name == "CachedFileDescriptor") + { + m_CachedFileDescriptor = armnn::ParseIntBackendOption(value, -1); + } }); } } @@ -69,4 +73,9 @@ bool ClBackendModelContext::SaveCachedNetwork() const return m_SaveCachedNetwork; } +int ClBackendModelContext::GetCachedFileDescriptor() const +{ + return m_CachedFileDescriptor; +} + } // namespace armnn \ No newline at end of file diff --git a/src/backends/cl/ClBackendModelContext.hpp b/src/backends/cl/ClBackendModelContext.hpp index e7a26cd688..4feb6aa452 100644 --- a/src/backends/cl/ClBackendModelContext.hpp +++ b/src/backends/cl/ClBackendModelContext.hpp @@ -36,10 +36,13 @@ public: bool SaveCachedNetwork() const; + int GetCachedFileDescriptor() const; + private: std::string m_CachedNetworkFilePath; bool m_IsFastMathEnabled; bool m_SaveCachedNetwork; + int m_CachedFileDescriptor; }; diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 2f94ef0970..134dad576e 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -32,6 +32,8 @@ #include #include +#include + namespace armnn { @@ -67,13 +69,29 @@ void ClWorkloadFactory::AfterWorkloadsCreated() auto modelOptions = dynamic_cast(m_ModelContextPtr.get()); if (modelOptions->SaveCachedNetwork()) { + ClContextSerializer serializer; + serializer.Serialize(m_CLCompileContext); + auto cachedFd = modelOptions->GetCachedFileDescriptor(); + if (cachedFd != -1) + { + std::vector compiledContextData; + std::stringstream stream; + bool serialized = serializer.SaveSerializedToStream(stream); + if (serialized) + { + std::string const serializedString{stream.str()}; + std::copy(serializedString.begin(), + serializedString.end(), + std::back_inserter(compiledContextData)); + write(cachedFd, compiledContextData.data(), compiledContextData.size()); + } + } + // Save map to a filepath provided in ModelOptions auto filePath = modelOptions->GetCachedNetworkFilePath(); if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath)) { // Serialize ClContext to the file specified - ClContextSerializer serializer; - serializer.Serialize(m_CLCompileContext); std::ofstream file(filePath, std::ios::out | std::ios::binary); serializer.SaveSerializedToStream(file); } @@ -123,14 +141,38 @@ void ClWorkloadFactory::InitializeCLCompileContext() // Load saved programs if the user has set a filepath auto modelOptions = dynamic_cast(m_ModelContextPtr.get()); auto filePath = modelOptions->GetCachedNetworkFilePath(); - if (filePath != "" - && fs::exists(filePath) - && fs::is_regular_file(filePath) - && !(modelOptions->SaveCachedNetwork())) + if (!(modelOptions->SaveCachedNetwork())) { - // Deserialize binary file and load into m_CLCompileContext ClContextDeserializer deserializer; - deserializer.Deserialize(m_CLCompileContext, context, device, filePath); + auto cachedFd = modelOptions->GetCachedFileDescriptor(); + if (cachedFd != -1) + { + struct stat statBuffer; + if (fstat(cachedFd, &statBuffer) == 0) + { + long dataSize = static_cast(statBuffer.st_size); + if( dataSize > 0) + { + auto offset = lseek(cachedFd, 0, SEEK_CUR); + if (offset == 0) + { + std::vector compiledContextData(static_cast(dataSize)); + pread(cachedFd, compiledContextData.data(), compiledContextData.size(), 0); + deserializer.DeserializeFromBinary(m_CLCompileContext, + context, + device, + compiledContextData); + } + } + + } + } + + if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath)) + { + // Deserialize binary file and load into m_CLCompileContext + deserializer.Deserialize(m_CLCompileContext, context, device, filePath); + } } } } -- cgit v1.2.1