diff options
Diffstat (limited to 'src/armnn/LayerSupport.cpp')
-rw-r--r-- | src/armnn/LayerSupport.cpp | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp index 5d2d205534..78a184a7ce 100644 --- a/src/armnn/LayerSupport.cpp +++ b/src/armnn/LayerSupport.cpp @@ -4,8 +4,10 @@ // #include <armnn/LayerSupport.hpp> #include <armnn/Optional.hpp> +#include <armnn/ILayerSupport.hpp> -#include <backendsCommon/LayerSupportRegistry.hpp> +#include <backendsCommon/BackendRegistry.hpp> +#include <backendsCommon/IBackendInternal.hpp> #include <boost/assert.hpp> @@ -39,10 +41,22 @@ void CopyErrorMessage(char* truncatedString, const char* fullString, size_t maxL std::string reasonIfUnsupportedFull; \ bool isSupported; \ try { \ - auto factoryFunc = LayerSupportRegistryInstance().GetFactory(backendId); \ - auto layerSupportObject = factoryFunc(); \ - isSupported = layerSupportObject->func(__VA_ARGS__, Optional<std::string&>(reasonIfUnsupportedFull)); \ - CopyErrorMessage(reasonIfUnsupported, reasonIfUnsupportedFull.c_str(), reasonIfUnsupportedMaxLength); \ + auto const& backendRegistry = BackendRegistryInstance(); \ + if (!backendRegistry.IsBackendRegistered(backendId)) \ + { \ + std::stringstream ss; \ + ss << __func__ << " is not supported on " << backendId << " because this backend is not registered."; \ + reasonIfUnsupportedFull = ss.str(); \ + isSupported = false; \ + } \ + else \ + { \ + auto factoryFunc = backendRegistry.GetFactory(backendId); \ + auto backendObject = factoryFunc(); \ + auto layerSupportObject = backendObject->GetLayerSupport(); \ + isSupported = layerSupportObject->func(__VA_ARGS__, Optional<std::string&>(reasonIfUnsupportedFull)); \ + CopyErrorMessage(reasonIfUnsupported, reasonIfUnsupportedFull.c_str(), reasonIfUnsupportedMaxLength); \ + } \ } catch (InvalidArgumentException e) { \ /* re-throwing with more context information */ \ throw InvalidArgumentException(e, "Failed to check layer support", CHECK_LOCATION()); \ |