aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-05-07 17:52:36 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-05-08 20:15:32 +0100
commite5f0b2409c2e557a5a78e2f4659d203154289b23 (patch)
tree0e32680ed15ed5157c78d5deeabda2c0ceeeb4a3 /include
parentae12306486efc55293a40048618abe5e8b19151b (diff)
downloadarmnn-e5f0b2409c2e557a5a78e2f4659d203154289b23.tar.gz
IVGCVSW-5818 Enable import on GPU
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I4e4eb107aa2bfa09625840d738001f33152e6792
Diffstat (limited to 'include')
-rw-r--r--include/armnn/IRuntime.hpp6
-rw-r--r--include/armnn/backends/IBackendInternal.hpp13
-rw-r--r--include/armnn/backends/ITensorHandleFactory.hpp3
3 files changed, 18 insertions, 4 deletions
diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp
index f296a5f564..870e027f33 100644
--- a/include/armnn/IRuntime.hpp
+++ b/include/armnn/IRuntime.hpp
@@ -38,9 +38,9 @@ struct INetworkProperties
, m_ExportEnabled(exportEnabled)
, m_AsyncEnabled(asyncEnabled)
, m_NumThreads(numThreads)
- , m_InputSource(MemorySource::Undefined)
- , m_OutputSource(MemorySource::Undefined)
- {}
+ , m_InputSource(m_ImportEnabled ? MemorySource::Malloc : MemorySource::Undefined)
+ , m_OutputSource(m_ExportEnabled ? MemorySource::Malloc : MemorySource::Undefined)
+ {}
INetworkProperties(bool asyncEnabled,
MemorySource m_InputSource,
diff --git a/include/armnn/backends/IBackendInternal.hpp b/include/armnn/backends/IBackendInternal.hpp
index 8035cff456..135d279c21 100644
--- a/include/armnn/backends/IBackendInternal.hpp
+++ b/include/armnn/backends/IBackendInternal.hpp
@@ -126,6 +126,12 @@ public:
class TensorHandleFactoryRegistry& tensorHandleFactoryRegistry,
const ModelOptions& modelOptions) const;
+ virtual IWorkloadFactoryPtr CreateWorkloadFactory(
+ class TensorHandleFactoryRegistry& tensorHandleFactoryRegistry,
+ const ModelOptions& modelOptions,
+ MemorySourceFlags inputFlags,
+ MemorySourceFlags outputFlags) const;
+
/// Create the runtime context of the backend
///
/// Implementations may return a default-constructed IBackendContextPtr if
@@ -162,6 +168,13 @@ public:
/// IWorkloadFactory::CreateTensor()/IWorkloadFactory::CreateSubtensor() methods must be implemented.
virtual void RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& /*registry*/) {}
+ /// (Optional) Register TensorHandleFactories
+ /// Either this method or CreateMemoryManager() and
+ /// IWorkloadFactory::CreateTensor()/IWorkloadFactory::CreateSubtensor() methods must be implemented.
+ virtual void RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry,
+ MemorySourceFlags inputFlags,
+ MemorySourceFlags outputFlags);
+
/// Returns the version of the Backend API
static constexpr BackendVersion GetApiVersion() { return BackendVersion(1, 0); }
diff --git a/include/armnn/backends/ITensorHandleFactory.hpp b/include/armnn/backends/ITensorHandleFactory.hpp
index ae2f44e8c6..501d97b852 100644
--- a/include/armnn/backends/ITensorHandleFactory.hpp
+++ b/include/armnn/backends/ITensorHandleFactory.hpp
@@ -20,6 +20,7 @@ namespace armnn
enum class CapabilityClass
{
PaddingRequired = 1,
+ FallbackImportDisabled = 2,
// add new enum values here
@@ -80,7 +81,7 @@ public:
virtual bool SupportsSubTensors() const = 0;
- virtual bool SupportsMapUnmap() const final { return true; }
+ virtual bool SupportsMapUnmap() const { return true; }
virtual MemorySourceFlags GetExportFlags() const { return 0; }
virtual MemorySourceFlags GetImportFlags() const { return 0; }