aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2021-10-06 16:37:02 +0100
committerSadik Armagan <sadik.armagan@arm.com>2021-10-28 08:17:26 +0000
commitb7851f9b29dbbe995acf6dc271471e73261c196a (patch)
treeb00b29ea8e4161b11473053a0957f4a902aca837
parentaf3a4ef77d8f330a995911b979417857514df62c (diff)
downloadarmnn-b7851f9b29dbbe995acf6dc271471e73261c196a.tar.gz
IVGCVSW-5636 'Implement NNAPI caching functions'
* Get number of inputs and outputs from optimized network. * Get number of cached files if backend supports caching. Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: Ie02ac123bb7df9b0593a2fe46b5bb564a5994780
-rw-r--r--include/armnn/BackendHelper.hpp3
-rw-r--r--include/armnn/BackendOptions.hpp9
-rw-r--r--include/armnn/INetwork.hpp3
-rw-r--r--include/armnn/backends/IBackendInternal.hpp5
-rw-r--r--src/armnn/BackendHelper.cpp12
-rw-r--r--src/armnn/Network.cpp20
-rw-r--r--src/armnn/OptimizedNetworkImpl.hpp3
-rw-r--r--src/backends/cl/ClBackend.hpp2
-rw-r--r--src/backends/cl/ClBackendModelContext.cpp11
-rw-r--r--src/backends/cl/ClBackendModelContext.hpp3
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp58
11 files changed, 120 insertions, 9 deletions
diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp
index 0bd37dcf29..03731ac24a 100644
--- a/include/armnn/BackendHelper.hpp
+++ b/include/armnn/BackendHelper.hpp
@@ -448,4 +448,7 @@ Optional<const BackendOptions::BackendOption> GetCapability(const std::string& b
ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This function has been deprecated in favour of GetBackendCapability", "22.05")
bool IsCapabilitySupported(const armnn::BackendId& backend, armnn::BackendCapability capability);
+/// Returns the number of cached files if backend supports caching
+unsigned int GetNumberOfCacheFiles(const armnn::BackendId& backend);
+
}
diff --git a/include/armnn/BackendOptions.hpp b/include/armnn/BackendOptions.hpp
index 33cecf6614..e5694493d3 100644
--- a/include/armnn/BackendOptions.hpp
+++ b/include/armnn/BackendOptions.hpp
@@ -314,4 +314,13 @@ inline std::string ParseStringBackendOption(const armnn::BackendOptions::Var& va
return defaultValue;
}
+inline int ParseIntBackendOption(const armnn::BackendOptions::Var& value, int defaultValue)
+{
+ if (value.IsInt())
+ {
+ return value.AsInt();
+ }
+ return defaultValue;
+}
+
} //namespace armnn
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 707ae00bb3..f85b29ee81 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -722,6 +722,9 @@ public:
profiling::ProfilingGuid GetGuid() const;
+ size_t GetNumInputs() const;
+ size_t GetNumOutputs() const;
+
// Creates a copy of the IOptimizedNetwork. The IOptimizedNetwork will not be reoptimized,
// the provided ModelOptions will only be used when creating a LoadedNetwork.
IOptimizedNetwork(const IOptimizedNetwork& other, const ModelOptions& modelOptions);
diff --git a/include/armnn/backends/IBackendInternal.hpp b/include/armnn/backends/IBackendInternal.hpp
index 7500e35897..9253e11d4f 100644
--- a/include/armnn/backends/IBackendInternal.hpp
+++ b/include/armnn/backends/IBackendInternal.hpp
@@ -206,6 +206,11 @@ public:
{
throw armnn::Exception("GetDefaultAllocator: Function has not been implemented in backend.");
}
+
+ /// Returns the number of files cached if backend supports caching
+ ///
+ /// \return - Returns 0 if backend does not support caching otherwise number of files cached
+ virtual unsigned int GetNumberOfCacheFiles() const { return 0; }
};
using IBackendInternalUniquePtr = std::unique_ptr<IBackendInternal>;
diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp
index cc792a06ef..c3cebddb2b 100644
--- a/src/armnn/BackendHelper.cpp
+++ b/src/armnn/BackendHelper.cpp
@@ -126,6 +126,18 @@ bool IsCapabilitySupported(const armnn::BackendId& backend, armnn::BackendCapabi
return hasCapability;
}
+unsigned int GetNumberOfCacheFiles(const armnn::BackendId& backend)
+{
+ auto const& backendRegistry = armnn::BackendRegistryInstance();
+ if (backendRegistry.IsBackendRegistered(backend))
+ {
+ auto factoryFunc = backendRegistry.GetFactory(backend);
+ auto backendObject = factoryFunc();
+ return backendObject->GetNumberOfCacheFiles();
+ }
+ return 0;
+}
+
bool LayerSupportHandle::IsBackendRegistered() const
{
if (m_LayerSupport)
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index b516d519d5..e00dbfc0fc 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -524,6 +524,16 @@ profiling::ProfilingGuid IOptimizedNetwork::GetGuid() const
return pOptimizedNetworkImpl->GetGuid();
}
+size_t IOptimizedNetwork::GetNumInputs() const
+{
+ return pOptimizedNetworkImpl->GetNumInputs();
+}
+
+size_t IOptimizedNetwork::GetNumOutputs() const
+{
+ return pOptimizedNetworkImpl->GetNumOutputs();
+}
+
Status OptimizedNetworkImpl::PrintGraph()
{
m_Graph->Print();
@@ -535,6 +545,16 @@ Status OptimizedNetworkImpl::SerializeToDot(std::ostream& stream) const
return m_Graph->SerializeToDot(stream);
}
+size_t OptimizedNetworkImpl::GetNumInputs() const
+{
+ return m_Graph->GetNumInputs();
+}
+
+size_t OptimizedNetworkImpl::GetNumOutputs() const
+{
+ return m_Graph->GetNumOutputs();
+}
+
void ReportError(const std::string& errorMessage,
Optional<std::vector<std::string>&> errorMessages)
{
diff --git a/src/armnn/OptimizedNetworkImpl.hpp b/src/armnn/OptimizedNetworkImpl.hpp
index d42cff7346..112d585aee 100644
--- a/src/armnn/OptimizedNetworkImpl.hpp
+++ b/src/armnn/OptimizedNetworkImpl.hpp
@@ -21,6 +21,9 @@ public:
virtual profiling::ProfilingGuid GetGuid() const { return m_Guid; };
+ virtual size_t GetNumInputs() const;
+ virtual size_t GetNumOutputs() const;
+
Graph& GetGraph() { return *m_Graph; }
ModelOptions& GetModelOptions() { return m_ModelOptions; }
diff --git a/src/backends/cl/ClBackend.hpp b/src/backends/cl/ClBackend.hpp
index e0708d18e2..ffce800261 100644
--- a/src/backends/cl/ClBackend.hpp
+++ b/src/backends/cl/ClBackend.hpp
@@ -105,6 +105,8 @@ public:
return m_UsingCustomAllocator;
}
+ virtual unsigned int GetNumberOfCacheFiles() const override { return 1; }
+
// Cl requires a arm_compute::IAllocator we wrap the Arm NN ICustomAllocator to achieve this
class ClBackendCustomAllocatorWrapper : public arm_compute::IAllocator
{
diff --git a/src/backends/cl/ClBackendModelContext.cpp b/src/backends/cl/ClBackendModelContext.cpp
index b685bc296c..75a2e05bda 100644
--- a/src/backends/cl/ClBackendModelContext.cpp
+++ b/src/backends/cl/ClBackendModelContext.cpp
@@ -32,7 +32,7 @@ namespace armnn
{
ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions)
- : m_CachedNetworkFilePath(""), m_IsFastMathEnabled(false), m_SaveCachedNetwork(false)
+ : m_CachedNetworkFilePath(""), m_IsFastMathEnabled(false), m_SaveCachedNetwork(false), m_CachedFileDescriptor(-1)
{
if (!modelOptions.empty())
{
@@ -50,6 +50,10 @@ ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions)
{
m_CachedNetworkFilePath = ParseFile(value, "");
}
+ if (name == "CachedFileDescriptor")
+ {
+ m_CachedFileDescriptor = armnn::ParseIntBackendOption(value, -1);
+ }
});
}
}
@@ -69,4 +73,9 @@ bool ClBackendModelContext::SaveCachedNetwork() const
return m_SaveCachedNetwork;
}
+int ClBackendModelContext::GetCachedFileDescriptor() const
+{
+ return m_CachedFileDescriptor;
+}
+
} // namespace armnn \ No newline at end of file
diff --git a/src/backends/cl/ClBackendModelContext.hpp b/src/backends/cl/ClBackendModelContext.hpp
index e7a26cd688..4feb6aa452 100644
--- a/src/backends/cl/ClBackendModelContext.hpp
+++ b/src/backends/cl/ClBackendModelContext.hpp
@@ -36,10 +36,13 @@ public:
bool SaveCachedNetwork() const;
+ int GetCachedFileDescriptor() const;
+
private:
std::string m_CachedNetworkFilePath;
bool m_IsFastMathEnabled;
bool m_SaveCachedNetwork;
+ int m_CachedFileDescriptor;
};
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index 2f94ef0970..134dad576e 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -32,6 +32,8 @@
#include <armnnUtils/Filesystem.hpp>
#include <fstream>
+#include <sys/stat.h>
+
namespace armnn
{
@@ -67,13 +69,29 @@ void ClWorkloadFactory::AfterWorkloadsCreated()
auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
if (modelOptions->SaveCachedNetwork())
{
+ ClContextSerializer serializer;
+ serializer.Serialize(m_CLCompileContext);
+ auto cachedFd = modelOptions->GetCachedFileDescriptor();
+ if (cachedFd != -1)
+ {
+ std::vector<uint8_t> compiledContextData;
+ std::stringstream stream;
+ bool serialized = serializer.SaveSerializedToStream(stream);
+ if (serialized)
+ {
+ std::string const serializedString{stream.str()};
+ std::copy(serializedString.begin(),
+ serializedString.end(),
+ std::back_inserter(compiledContextData));
+ write(cachedFd, compiledContextData.data(), compiledContextData.size());
+ }
+ }
+
// Save map to a filepath provided in ModelOptions
auto filePath = modelOptions->GetCachedNetworkFilePath();
if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath))
{
// Serialize ClContext to the file specified
- ClContextSerializer serializer;
- serializer.Serialize(m_CLCompileContext);
std::ofstream file(filePath, std::ios::out | std::ios::binary);
serializer.SaveSerializedToStream(file);
}
@@ -123,14 +141,38 @@ void ClWorkloadFactory::InitializeCLCompileContext()
// Load saved programs if the user has set a filepath
auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
auto filePath = modelOptions->GetCachedNetworkFilePath();
- if (filePath != ""
- && fs::exists(filePath)
- && fs::is_regular_file(filePath)
- && !(modelOptions->SaveCachedNetwork()))
+ if (!(modelOptions->SaveCachedNetwork()))
{
- // Deserialize binary file and load into m_CLCompileContext
ClContextDeserializer deserializer;
- deserializer.Deserialize(m_CLCompileContext, context, device, filePath);
+ auto cachedFd = modelOptions->GetCachedFileDescriptor();
+ if (cachedFd != -1)
+ {
+ struct stat statBuffer;
+ if (fstat(cachedFd, &statBuffer) == 0)
+ {
+ long dataSize = static_cast<long>(statBuffer.st_size);
+ if( dataSize > 0)
+ {
+ auto offset = lseek(cachedFd, 0, SEEK_CUR);
+ if (offset == 0)
+ {
+ std::vector <uint8_t> compiledContextData(static_cast<unsigned int>(dataSize));
+ pread(cachedFd, compiledContextData.data(), compiledContextData.size(), 0);
+ deserializer.DeserializeFromBinary(m_CLCompileContext,
+ context,
+ device,
+ compiledContextData);
+ }
+ }
+
+ }
+ }
+
+ if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath))
+ {
+ // Deserialize binary file and load into m_CLCompileContext
+ deserializer.Deserialize(m_CLCompileContext, context, device, filePath);
+ }
}
}
}