diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index ec30f34880..bb63b336e9 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -10,14 +10,17 @@ #include <armnn/Types.hpp> #include <armnn/LayerSupport.hpp> +#include <armnn/ILayerSupport.hpp> -#include <backendsCommon/LayerSupportRegistry.hpp> +#include <backendsCommon/BackendRegistry.hpp> #include <backendsCommon/WorkloadFactory.hpp> +#include <backendsCommon/IBackendInternal.hpp> #include <boost/cast.hpp> #include <boost/iterator/transform_iterator.hpp> #include <cstring> +#include <sstream> namespace armnn { @@ -66,9 +69,20 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, bool result; const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer)); - auto const& layerSupportRegistry = LayerSupportRegistryInstance(); - auto layerSupportFactory = layerSupportRegistry.GetFactory(backendId); - auto layerSupportObject = layerSupportFactory(); + auto const& backendRegistry = BackendRegistryInstance(); + if (!backendRegistry.IsBackendRegistered(backendId)) + { + std::stringstream ss; + ss << connectableLayer.GetName() << " is not supported on " << backendId + << " because this backend is not registered."; + + outReasonIfUnsupported = ss.str(); + return false; + } + + auto backendFactory = backendRegistry.GetFactory(backendId); + auto backendObject = backendFactory(); + auto layerSupportObject = backendObject->GetLayerSupport(); switch(layer.GetType()) { |