aboutsummaryrefslogtreecommitdiff
path: root/ConversionUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r--ConversionUtils.hpp59
1 files changed, 39 insertions, 20 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index ca1f0aea..de4516c0 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -19,6 +19,7 @@
#include <boost/test/tools/floating_point_comparison.hpp>
#include <log/log.h>
+#include <vector>
namespace armnn_driver
{
@@ -29,12 +30,12 @@ namespace armnn_driver
struct ConversionData
{
- ConversionData(armnn::Compute compute)
- : m_Compute(compute)
- , m_Network(nullptr, nullptr)
+ ConversionData(const std::vector<armnn::BackendId>& backends)
+ : m_Backends(backends)
+ , m_Network(nullptr, nullptr)
{}
- const armnn::Compute m_Compute;
+ const std::vector<armnn::BackendId> m_Backends;
armnn::INetworkPtr m_Network;
std::vector<armnn::IOutputSlot*> m_OutputSlotForOperand;
std::vector<android::nn::RunTimePoolInfo> m_MemPools;
@@ -139,6 +140,24 @@ bool IsLayerSupported(const char* funcName, IsLayerSupportedFunc f, Args&&... ar
}
}
+template<typename IsLayerSupportedFunc, typename ... Args>
+bool IsLayerSupportedForAnyBackend(const char* funcName,
+ IsLayerSupportedFunc f,
+ const std::vector<armnn::BackendId>& backends,
+ Args&&... args)
+{
+ for (auto&& backend : backends)
+ {
+ if (IsLayerSupported(funcName, f, backend, std::forward<Args>(args)...))
+ {
+ return true;
+ }
+ }
+
+ ALOGD("%s: not supported by any specified backend", funcName);
+ return false;
+}
+
armnn::TensorShape GetTensorShapeForOperand(const Operand& operand)
{
return armnn::TensorShape(operand.dimensions.size(), operand.dimensions.data());
@@ -809,10 +828,10 @@ LayerInputHandle ConvertToLayerInputHandle(const HalOperation& operation,
ConstTensorPin tensorPin = ConvertOperandToConstTensorPin(*operand, model, data);
if (tensorPin.IsValid())
{
- if (!IsLayerSupported(__func__,
- armnn::IsConstantSupported,
- data.m_Compute,
- tensorPin.GetConstTensor().GetInfo()))
+ if (!IsLayerSupportedForAnyBackend(__func__,
+ armnn::IsConstantSupported,
+ data.m_Backends,
+ tensorPin.GetConstTensor().GetInfo()))
{
return LayerInputHandle();
}
@@ -859,12 +878,12 @@ bool ConvertToActivation(const HalOperation& operation,
return false;
}
const armnn::TensorInfo outInfo = GetTensorInfoForOperand(*outputOperand);
- if (!IsLayerSupported(__func__,
- armnn::IsActivationSupported,
- data.m_Compute,
- input.GetTensorInfo(),
- outInfo,
- activationDesc))
+ if (!IsLayerSupportedForAnyBackend(__func__,
+ armnn::IsActivationSupported,
+ data.m_Backends,
+ input.GetTensorInfo(),
+ outInfo,
+ activationDesc))
{
return false;
}
@@ -976,12 +995,12 @@ bool ConvertPooling2d(const HalOperation& operation,
}
}
- if (!IsLayerSupported(__func__,
- armnn::IsPooling2dSupported,
- data.m_Compute,
- inputInfo,
- outputInfo,
- desc))
+ if (!IsLayerSupportedForAnyBackend(__func__,
+ armnn::IsPooling2dSupported,
+ data.m_Backends,
+ inputInfo,
+ outputInfo,
+ desc))
{
return false;
}