diff options
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r-- | ConversionUtils.hpp | 59 |
1 files changed, 39 insertions, 20 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index ca1f0aea..de4516c0 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -19,6 +19,7 @@ #include <boost/test/tools/floating_point_comparison.hpp> #include <log/log.h> +#include <vector> namespace armnn_driver { @@ -29,12 +30,12 @@ namespace armnn_driver struct ConversionData { - ConversionData(armnn::Compute compute) - : m_Compute(compute) - , m_Network(nullptr, nullptr) + ConversionData(const std::vector<armnn::BackendId>& backends) + : m_Backends(backends) + , m_Network(nullptr, nullptr) {} - const armnn::Compute m_Compute; + const std::vector<armnn::BackendId> m_Backends; armnn::INetworkPtr m_Network; std::vector<armnn::IOutputSlot*> m_OutputSlotForOperand; std::vector<android::nn::RunTimePoolInfo> m_MemPools; @@ -139,6 +140,24 @@ bool IsLayerSupported(const char* funcName, IsLayerSupportedFunc f, Args&&... ar } } +template<typename IsLayerSupportedFunc, typename ... Args> +bool IsLayerSupportedForAnyBackend(const char* funcName, + IsLayerSupportedFunc f, + const std::vector<armnn::BackendId>& backends, + Args&&... args) +{ + for (auto&& backend : backends) + { + if (IsLayerSupported(funcName, f, backend, std::forward<Args>(args)...)) + { + return true; + } + } + + ALOGD("%s: not supported by any specified backend", funcName); + return false; +} + armnn::TensorShape GetTensorShapeForOperand(const Operand& operand) { return armnn::TensorShape(operand.dimensions.size(), operand.dimensions.data()); @@ -809,10 +828,10 @@ LayerInputHandle ConvertToLayerInputHandle(const HalOperation& operation, ConstTensorPin tensorPin = ConvertOperandToConstTensorPin(*operand, model, data); if (tensorPin.IsValid()) { - if (!IsLayerSupported(__func__, - armnn::IsConstantSupported, - data.m_Compute, - tensorPin.GetConstTensor().GetInfo())) + if (!IsLayerSupportedForAnyBackend(__func__, + armnn::IsConstantSupported, + data.m_Backends, + tensorPin.GetConstTensor().GetInfo())) { return LayerInputHandle(); } @@ -859,12 +878,12 @@ bool ConvertToActivation(const HalOperation& operation, return false; } const armnn::TensorInfo outInfo = GetTensorInfoForOperand(*outputOperand); - if (!IsLayerSupported(__func__, - armnn::IsActivationSupported, - data.m_Compute, - input.GetTensorInfo(), - outInfo, - activationDesc)) + if (!IsLayerSupportedForAnyBackend(__func__, + armnn::IsActivationSupported, + data.m_Backends, + input.GetTensorInfo(), + outInfo, + activationDesc)) { return false; } @@ -976,12 +995,12 @@ bool ConvertPooling2d(const HalOperation& operation, } } - if (!IsLayerSupported(__func__, - armnn::IsPooling2dSupported, - data.m_Compute, - inputInfo, - outputInfo, - desc)) + if (!IsLayerSupportedForAnyBackend(__func__, + armnn::IsPooling2dSupported, + data.m_Backends, + inputInfo, + outputInfo, + desc)) { return false; } |