From b7851f9b29dbbe995acf6dc271471e73261c196a Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Wed, 6 Oct 2021 16:37:02 +0100 Subject: IVGCVSW-5636 'Implement NNAPI caching functions' * Get number of inputs and outputs from optimized network. * Get number of cached files if backend supports caching. Signed-off-by: Sadik Armagan Change-Id: Ie02ac123bb7df9b0593a2fe46b5bb564a5994780 --- include/armnn/BackendHelper.hpp | 3 ++ include/armnn/BackendOptions.hpp | 9 +++++ include/armnn/INetwork.hpp | 3 ++ include/armnn/backends/IBackendInternal.hpp | 5 +++ src/armnn/BackendHelper.cpp | 12 ++++++ src/armnn/Network.cpp | 20 ++++++++++ src/armnn/OptimizedNetworkImpl.hpp | 3 ++ src/backends/cl/ClBackend.hpp | 2 + src/backends/cl/ClBackendModelContext.cpp | 11 +++++- src/backends/cl/ClBackendModelContext.hpp | 3 ++ src/backends/cl/ClWorkloadFactory.cpp | 58 +++++++++++++++++++++++++---- 11 files changed, 120 insertions(+), 9 deletions(-) diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp index 0bd37dcf29..03731ac24a 100644 --- a/include/armnn/BackendHelper.hpp +++ b/include/armnn/BackendHelper.hpp @@ -448,4 +448,7 @@ Optional GetCapability(const std::string& b ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This function has been deprecated in favour of GetBackendCapability", "22.05") bool IsCapabilitySupported(const armnn::BackendId& backend, armnn::BackendCapability capability); +/// Returns the number of cached files if backend supports caching +unsigned int GetNumberOfCacheFiles(const armnn::BackendId& backend); + } diff --git a/include/armnn/BackendOptions.hpp b/include/armnn/BackendOptions.hpp index 33cecf6614..e5694493d3 100644 --- a/include/armnn/BackendOptions.hpp +++ b/include/armnn/BackendOptions.hpp @@ -314,4 +314,13 @@ inline std::string ParseStringBackendOption(const armnn::BackendOptions::Var& va return defaultValue; } +inline int ParseIntBackendOption(const armnn::BackendOptions::Var& value, int defaultValue) +{ + if (value.IsInt()) + { + return value.AsInt(); + } + return defaultValue; +} + } //namespace armnn diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index 707ae00bb3..f85b29ee81 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -722,6 +722,9 @@ public: profiling::ProfilingGuid GetGuid() const; + size_t GetNumInputs() const; + size_t GetNumOutputs() const; + // Creates a copy of the IOptimizedNetwork. The IOptimizedNetwork will not be reoptimized, // the provided ModelOptions will only be used when creating a LoadedNetwork. IOptimizedNetwork(const IOptimizedNetwork& other, const ModelOptions& modelOptions); diff --git a/include/armnn/backends/IBackendInternal.hpp b/include/armnn/backends/IBackendInternal.hpp index 7500e35897..9253e11d4f 100644 --- a/include/armnn/backends/IBackendInternal.hpp +++ b/include/armnn/backends/IBackendInternal.hpp @@ -206,6 +206,11 @@ public: { throw armnn::Exception("GetDefaultAllocator: Function has not been implemented in backend."); } + + /// Returns the number of files cached if backend supports caching + /// + /// \return - Returns 0 if backend does not support caching otherwise number of files cached + virtual unsigned int GetNumberOfCacheFiles() const { return 0; } }; using IBackendInternalUniquePtr = std::unique_ptr; diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp index cc792a06ef..c3cebddb2b 100644 --- a/src/armnn/BackendHelper.cpp +++ b/src/armnn/BackendHelper.cpp @@ -126,6 +126,18 @@ bool IsCapabilitySupported(const armnn::BackendId& backend, armnn::BackendCapabi return hasCapability; } +unsigned int GetNumberOfCacheFiles(const armnn::BackendId& backend) +{ + auto const& backendRegistry = armnn::BackendRegistryInstance(); + if (backendRegistry.IsBackendRegistered(backend)) + { + auto factoryFunc = backendRegistry.GetFactory(backend); + auto backendObject = factoryFunc(); + return backendObject->GetNumberOfCacheFiles(); + } + return 0; +} + bool LayerSupportHandle::IsBackendRegistered() const { if (m_LayerSupport) diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index b516d519d5..e00dbfc0fc 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -524,6 +524,16 @@ profiling::ProfilingGuid IOptimizedNetwork::GetGuid() const return pOptimizedNetworkImpl->GetGuid(); } +size_t IOptimizedNetwork::GetNumInputs() const +{ + return pOptimizedNetworkImpl->GetNumInputs(); +} + +size_t IOptimizedNetwork::GetNumOutputs() const +{ + return pOptimizedNetworkImpl->GetNumOutputs(); +} + Status OptimizedNetworkImpl::PrintGraph() { m_Graph->Print(); @@ -535,6 +545,16 @@ Status OptimizedNetworkImpl::SerializeToDot(std::ostream& stream) const return m_Graph->SerializeToDot(stream); } +size_t OptimizedNetworkImpl::GetNumInputs() const +{ + return m_Graph->GetNumInputs(); +} + +size_t OptimizedNetworkImpl::GetNumOutputs() const +{ + return m_Graph->GetNumOutputs(); +} + void ReportError(const std::string& errorMessage, Optional&> errorMessages) { diff --git a/src/armnn/OptimizedNetworkImpl.hpp b/src/armnn/OptimizedNetworkImpl.hpp index d42cff7346..112d585aee 100644 --- a/src/armnn/OptimizedNetworkImpl.hpp +++ b/src/armnn/OptimizedNetworkImpl.hpp @@ -21,6 +21,9 @@ public: virtual profiling::ProfilingGuid GetGuid() const { return m_Guid; }; + virtual size_t GetNumInputs() const; + virtual size_t GetNumOutputs() const; + Graph& GetGraph() { return *m_Graph; } ModelOptions& GetModelOptions() { return m_ModelOptions; } diff --git a/src/backends/cl/ClBackend.hpp b/src/backends/cl/ClBackend.hpp index e0708d18e2..ffce800261 100644 --- a/src/backends/cl/ClBackend.hpp +++ b/src/backends/cl/ClBackend.hpp @@ -105,6 +105,8 @@ public: return m_UsingCustomAllocator; } + virtual unsigned int GetNumberOfCacheFiles() const override { return 1; } + // Cl requires a arm_compute::IAllocator we wrap the Arm NN ICustomAllocator to achieve this class ClBackendCustomAllocatorWrapper : public arm_compute::IAllocator { diff --git a/src/backends/cl/ClBackendModelContext.cpp b/src/backends/cl/ClBackendModelContext.cpp index b685bc296c..75a2e05bda 100644 --- a/src/backends/cl/ClBackendModelContext.cpp +++ b/src/backends/cl/ClBackendModelContext.cpp @@ -32,7 +32,7 @@ namespace armnn { ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions) - : m_CachedNetworkFilePath(""), m_IsFastMathEnabled(false), m_SaveCachedNetwork(false) + : m_CachedNetworkFilePath(""), m_IsFastMathEnabled(false), m_SaveCachedNetwork(false), m_CachedFileDescriptor(-1) { if (!modelOptions.empty()) { @@ -50,6 +50,10 @@ ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions) { m_CachedNetworkFilePath = ParseFile(value, ""); } + if (name == "CachedFileDescriptor") + { + m_CachedFileDescriptor = armnn::ParseIntBackendOption(value, -1); + } }); } } @@ -69,4 +73,9 @@ bool ClBackendModelContext::SaveCachedNetwork() const return m_SaveCachedNetwork; } +int ClBackendModelContext::GetCachedFileDescriptor() const +{ + return m_CachedFileDescriptor; +} + } // namespace armnn \ No newline at end of file diff --git a/src/backends/cl/ClBackendModelContext.hpp b/src/backends/cl/ClBackendModelContext.hpp index e7a26cd688..4feb6aa452 100644 --- a/src/backends/cl/ClBackendModelContext.hpp +++ b/src/backends/cl/ClBackendModelContext.hpp @@ -36,10 +36,13 @@ public: bool SaveCachedNetwork() const; + int GetCachedFileDescriptor() const; + private: std::string m_CachedNetworkFilePath; bool m_IsFastMathEnabled; bool m_SaveCachedNetwork; + int m_CachedFileDescriptor; }; diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 2f94ef0970..134dad576e 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -32,6 +32,8 @@ #include #include +#include + namespace armnn { @@ -67,13 +69,29 @@ void ClWorkloadFactory::AfterWorkloadsCreated() auto modelOptions = dynamic_cast(m_ModelContextPtr.get()); if (modelOptions->SaveCachedNetwork()) { + ClContextSerializer serializer; + serializer.Serialize(m_CLCompileContext); + auto cachedFd = modelOptions->GetCachedFileDescriptor(); + if (cachedFd != -1) + { + std::vector compiledContextData; + std::stringstream stream; + bool serialized = serializer.SaveSerializedToStream(stream); + if (serialized) + { + std::string const serializedString{stream.str()}; + std::copy(serializedString.begin(), + serializedString.end(), + std::back_inserter(compiledContextData)); + write(cachedFd, compiledContextData.data(), compiledContextData.size()); + } + } + // Save map to a filepath provided in ModelOptions auto filePath = modelOptions->GetCachedNetworkFilePath(); if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath)) { // Serialize ClContext to the file specified - ClContextSerializer serializer; - serializer.Serialize(m_CLCompileContext); std::ofstream file(filePath, std::ios::out | std::ios::binary); serializer.SaveSerializedToStream(file); } @@ -123,14 +141,38 @@ void ClWorkloadFactory::InitializeCLCompileContext() // Load saved programs if the user has set a filepath auto modelOptions = dynamic_cast(m_ModelContextPtr.get()); auto filePath = modelOptions->GetCachedNetworkFilePath(); - if (filePath != "" - && fs::exists(filePath) - && fs::is_regular_file(filePath) - && !(modelOptions->SaveCachedNetwork())) + if (!(modelOptions->SaveCachedNetwork())) { - // Deserialize binary file and load into m_CLCompileContext ClContextDeserializer deserializer; - deserializer.Deserialize(m_CLCompileContext, context, device, filePath); + auto cachedFd = modelOptions->GetCachedFileDescriptor(); + if (cachedFd != -1) + { + struct stat statBuffer; + if (fstat(cachedFd, &statBuffer) == 0) + { + long dataSize = static_cast(statBuffer.st_size); + if( dataSize > 0) + { + auto offset = lseek(cachedFd, 0, SEEK_CUR); + if (offset == 0) + { + std::vector compiledContextData(static_cast(dataSize)); + pread(cachedFd, compiledContextData.data(), compiledContextData.size(), 0); + deserializer.DeserializeFromBinary(m_CLCompileContext, + context, + device, + compiledContextData); + } + } + + } + } + + if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath)) + { + // Deserialize binary file and load into m_CLCompileContext + deserializer.Deserialize(m_CLCompileContext, context, device, filePath); + } } } } -- cgit v1.2.1