aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/LayerSupport.cpp
diff options
context:
space:
mode:
authorDavid Beck <david.beck@arm.com>2018-11-12 14:59:37 +0000
committerAron Virginas-Tar <aron.virginas-tar@arm.com>2018-11-12 16:02:51 +0000
commit111b5d94d7e854c21377f8d2c0b4234317a903f6 (patch)
tree68111e5d89b605c898b2327cb59b915e3ff64ce9 /src/armnn/LayerSupport.cpp
parent4e1e136cce3fca73ba49b570cfcb620f4ec574da (diff)
downloadarmnn-111b5d94d7e854c21377f8d2c0b4234317a903f6.tar.gz
IVGCVSW-2125 : Consolidate backend registries into one
Change-Id: I56da4780f8f5fcef7ff01d232d5d61bf299364bf
Diffstat (limited to 'src/armnn/LayerSupport.cpp')
-rw-r--r--src/armnn/LayerSupport.cpp24
1 files changed, 19 insertions, 5 deletions
diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp
index 5d2d205534..78a184a7ce 100644
--- a/src/armnn/LayerSupport.cpp
+++ b/src/armnn/LayerSupport.cpp
@@ -4,8 +4,10 @@
//
#include <armnn/LayerSupport.hpp>
#include <armnn/Optional.hpp>
+#include <armnn/ILayerSupport.hpp>
-#include <backendsCommon/LayerSupportRegistry.hpp>
+#include <backendsCommon/BackendRegistry.hpp>
+#include <backendsCommon/IBackendInternal.hpp>
#include <boost/assert.hpp>
@@ -39,10 +41,22 @@ void CopyErrorMessage(char* truncatedString, const char* fullString, size_t maxL
std::string reasonIfUnsupportedFull; \
bool isSupported; \
try { \
- auto factoryFunc = LayerSupportRegistryInstance().GetFactory(backendId); \
- auto layerSupportObject = factoryFunc(); \
- isSupported = layerSupportObject->func(__VA_ARGS__, Optional<std::string&>(reasonIfUnsupportedFull)); \
- CopyErrorMessage(reasonIfUnsupported, reasonIfUnsupportedFull.c_str(), reasonIfUnsupportedMaxLength); \
+ auto const& backendRegistry = BackendRegistryInstance(); \
+ if (!backendRegistry.IsBackendRegistered(backendId)) \
+ { \
+ std::stringstream ss; \
+ ss << __func__ << " is not supported on " << backendId << " because this backend is not registered."; \
+ reasonIfUnsupportedFull = ss.str(); \
+ isSupported = false; \
+ } \
+ else \
+ { \
+ auto factoryFunc = backendRegistry.GetFactory(backendId); \
+ auto backendObject = factoryFunc(); \
+ auto layerSupportObject = backendObject->GetLayerSupport(); \
+ isSupported = layerSupportObject->func(__VA_ARGS__, Optional<std::string&>(reasonIfUnsupportedFull)); \
+ CopyErrorMessage(reasonIfUnsupported, reasonIfUnsupportedFull.c_str(), reasonIfUnsupportedMaxLength); \
+ } \
} catch (InvalidArgumentException e) { \
/* re-throwing with more context information */ \
throw InvalidArgumentException(e, "Failed to check layer support", CHECK_LOCATION()); \