diff options
-rw-r--r-- | include/armnn/Exceptions.hpp | 9 | ||||
-rw-r--r-- | include/armnn/backends/IBackendInternal.hpp | 9 | ||||
-rw-r--r-- | src/armnn/DeviceSpec.hpp | 4 | ||||
-rw-r--r-- | src/armnn/Runtime.cpp | 13 | ||||
-rw-r--r-- | src/armnn/Runtime.hpp | 2 |
5 files changed, 28 insertions, 9 deletions
diff --git a/include/armnn/Exceptions.hpp b/include/armnn/Exceptions.hpp index e21e974fc7..066f59f792 100644 --- a/include/armnn/Exceptions.hpp +++ b/include/armnn/Exceptions.hpp @@ -64,12 +64,19 @@ private: std::string m_Message; }; -class ClRuntimeUnavailableException : public Exception +/// Class for non-fatal exceptions raised while initialising a backend +class BackendUnavailableException : public Exception { public: using Exception::Exception; }; +class ClRuntimeUnavailableException : public BackendUnavailableException +{ +public: + using BackendUnavailableException::BackendUnavailableException; +}; + class InvalidArgumentException : public Exception { public: diff --git a/include/armnn/backends/IBackendInternal.hpp b/include/armnn/backends/IBackendInternal.hpp index 29097b4ae7..6c919ee5d4 100644 --- a/include/armnn/backends/IBackendInternal.hpp +++ b/include/armnn/backends/IBackendInternal.hpp @@ -115,9 +115,16 @@ public: virtual IWorkloadFactoryPtr CreateWorkloadFactory( class TensorHandleFactoryRegistry& tensorHandleFactoryRegistry) const; + /// Create the runtime context of the backend + /// + /// Implementations may return a default-constructed IBackendContextPtr if + /// no context is needed at runtime. + /// Implementations must throw BackendUnavailableException if the backend + /// cannot be used (for example, necessary accelerator hardware is not present). + /// The default implementation always returns a default-constructed pointer. virtual IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const; - // Context specifically used for profiling interaction from backends. + /// Create context specifically used for profiling interaction from backends. virtual IBackendProfilingContextPtr CreateBackendProfilingContext(const IRuntime::CreationOptions& creationOptions, armnn::profiling::IBackendProfiling& backendProfiling) const; diff --git a/src/armnn/DeviceSpec.hpp b/src/armnn/DeviceSpec.hpp index 703a4b123f..a1457cf80e 100644 --- a/src/armnn/DeviceSpec.hpp +++ b/src/armnn/DeviceSpec.hpp @@ -14,6 +14,9 @@ namespace armnn class DeviceSpec : public IDeviceSpec { public: + DeviceSpec() + {} + DeviceSpec(const BackendIdSet& supportedBackends) : m_SupportedBackends{supportedBackends} {} @@ -48,7 +51,6 @@ public: } private: - DeviceSpec() = delete; BackendIdSet m_SupportedBackends; BackendIdSet m_DynamicBackends; }; diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index 2d7269a09c..47c998a6fd 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -153,7 +153,6 @@ const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const Runtime::Runtime(const CreationOptions& options) : m_NetworkIdCounter(0) - , m_DeviceSpec{BackendRegistryInstance().GetBackendIds()} { ARMNN_LOG(info) << "ArmNN v" << ARMNN_VERSION << "\n"; @@ -164,12 +163,11 @@ Runtime::Runtime(const CreationOptions& options) // goes through the backend registry LoadDynamicBackends(options.m_DynamicBackendsPath); + BackendIdSet supportedBackends; for (const auto& id : BackendRegistryInstance().GetBackendIds()) { // Store backend contexts for the supported ones - const BackendIdSet& supportedBackends = m_DeviceSpec.GetSupportedBackends(); - if (supportedBackends.find(id) != supportedBackends.end()) - { + try { auto factoryFun = BackendRegistryInstance().GetFactory(id); auto backend = factoryFun(); BOOST_ASSERT(backend.get() != nullptr); @@ -182,8 +180,15 @@ Runtime::Runtime(const CreationOptions& options) { m_BackendContexts.emplace(std::make_pair(id, std::move(context))); } + supportedBackends.emplace(id); + } + catch (const BackendUnavailableException&) + { + // Ignore backends which are unavailable } + } + m_DeviceSpec.AddSupportedBackends(supportedBackends); } Runtime::~Runtime() diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp index e5debbf9ac..2ad3c9633c 100644 --- a/src/armnn/Runtime.hpp +++ b/src/armnn/Runtime.hpp @@ -73,8 +73,6 @@ public: virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) override; /// Creates a runtime for workload execution. - /// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but - /// it cannot be setup for some reason. Runtime(const CreationOptions& options); ~Runtime(); |