aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp22
1 files changed, 18 insertions, 4 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index ec30f34880..bb63b336e9 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -10,14 +10,17 @@
#include <armnn/Types.hpp>
#include <armnn/LayerSupport.hpp>
+#include <armnn/ILayerSupport.hpp>
-#include <backendsCommon/LayerSupportRegistry.hpp>
+#include <backendsCommon/BackendRegistry.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
+#include <backendsCommon/IBackendInternal.hpp>
#include <boost/cast.hpp>
#include <boost/iterator/transform_iterator.hpp>
#include <cstring>
+#include <sstream>
namespace armnn
{
@@ -66,9 +69,20 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
bool result;
const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
- auto const& layerSupportRegistry = LayerSupportRegistryInstance();
- auto layerSupportFactory = layerSupportRegistry.GetFactory(backendId);
- auto layerSupportObject = layerSupportFactory();
+ auto const& backendRegistry = BackendRegistryInstance();
+ if (!backendRegistry.IsBackendRegistered(backendId))
+ {
+ std::stringstream ss;
+ ss << connectableLayer.GetName() << " is not supported on " << backendId
+ << " because this backend is not registered.";
+
+ outReasonIfUnsupported = ss.str();
+ return false;
+ }
+
+ auto backendFactory = backendRegistry.GetFactory(backendId);
+ auto backendObject = backendFactory();
+ auto layerSupportObject = backendObject->GetLayerSupport();
switch(layer.GetType())
{