diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/armnn/IRuntime.hpp | 6 | ||||
-rw-r--r-- | include/armnn/backends/IBackendInternal.hpp | 13 | ||||
-rw-r--r-- | include/armnn/backends/ITensorHandleFactory.hpp | 3 |
3 files changed, 18 insertions, 4 deletions
diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp index f296a5f564..870e027f33 100644 --- a/include/armnn/IRuntime.hpp +++ b/include/armnn/IRuntime.hpp @@ -38,9 +38,9 @@ struct INetworkProperties , m_ExportEnabled(exportEnabled) , m_AsyncEnabled(asyncEnabled) , m_NumThreads(numThreads) - , m_InputSource(MemorySource::Undefined) - , m_OutputSource(MemorySource::Undefined) - {} + , m_InputSource(m_ImportEnabled ? MemorySource::Malloc : MemorySource::Undefined) + , m_OutputSource(m_ExportEnabled ? MemorySource::Malloc : MemorySource::Undefined) + {} INetworkProperties(bool asyncEnabled, MemorySource m_InputSource, diff --git a/include/armnn/backends/IBackendInternal.hpp b/include/armnn/backends/IBackendInternal.hpp index 8035cff456..135d279c21 100644 --- a/include/armnn/backends/IBackendInternal.hpp +++ b/include/armnn/backends/IBackendInternal.hpp @@ -126,6 +126,12 @@ public: class TensorHandleFactoryRegistry& tensorHandleFactoryRegistry, const ModelOptions& modelOptions) const; + virtual IWorkloadFactoryPtr CreateWorkloadFactory( + class TensorHandleFactoryRegistry& tensorHandleFactoryRegistry, + const ModelOptions& modelOptions, + MemorySourceFlags inputFlags, + MemorySourceFlags outputFlags) const; + /// Create the runtime context of the backend /// /// Implementations may return a default-constructed IBackendContextPtr if @@ -162,6 +168,13 @@ public: /// IWorkloadFactory::CreateTensor()/IWorkloadFactory::CreateSubtensor() methods must be implemented. virtual void RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& /*registry*/) {} + /// (Optional) Register TensorHandleFactories + /// Either this method or CreateMemoryManager() and + /// IWorkloadFactory::CreateTensor()/IWorkloadFactory::CreateSubtensor() methods must be implemented. + virtual void RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry, + MemorySourceFlags inputFlags, + MemorySourceFlags outputFlags); + /// Returns the version of the Backend API static constexpr BackendVersion GetApiVersion() { return BackendVersion(1, 0); } diff --git a/include/armnn/backends/ITensorHandleFactory.hpp b/include/armnn/backends/ITensorHandleFactory.hpp index ae2f44e8c6..501d97b852 100644 --- a/include/armnn/backends/ITensorHandleFactory.hpp +++ b/include/armnn/backends/ITensorHandleFactory.hpp @@ -20,6 +20,7 @@ namespace armnn enum class CapabilityClass { PaddingRequired = 1, + FallbackImportDisabled = 2, // add new enum values here @@ -80,7 +81,7 @@ public: virtual bool SupportsSubTensors() const = 0; - virtual bool SupportsMapUnmap() const final { return true; } + virtual bool SupportsMapUnmap() const { return true; } virtual MemorySourceFlags GetExportFlags() const { return 0; } virtual MemorySourceFlags GetImportFlags() const { return 0; } |