From 29c75de868ac3a86a70b25f8da0d0c7e47d40803 Mon Sep 17 00:00:00 2001 From: David Beck Date: Tue, 23 Oct 2018 13:35:58 +0100 Subject: IVGCVSW-2067 : dynamically create workload factories based on the backends in the network Change-Id: Ide594db8c79ff67642721d8bad47624b88621fbd --- src/backends/WorkloadFactory.cpp | 54 +++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 28 deletions(-) (limited to 'src/backends/WorkloadFactory.cpp') diff --git a/src/backends/WorkloadFactory.cpp b/src/backends/WorkloadFactory.cpp index e7dec49db4..fea383f030 100644 --- a/src/backends/WorkloadFactory.cpp +++ b/src/backends/WorkloadFactory.cpp @@ -5,10 +5,6 @@ #include #include -#include -#include -#include - #include #include #include @@ -24,40 +20,42 @@ namespace armnn namespace { - const TensorInfo OverrideDataType(const TensorInfo& info, boost::optional type) + +const TensorInfo OverrideDataType(const TensorInfo& info, Optional type) +{ + if (!type) { - if (type == boost::none) - { - return info; - } + return info; + } + + return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset()); +} - return TensorInfo(info.GetShape(), type.get(), info.GetQuantizationScale(), info.GetQuantizationOffset()); +Optional GetBiasTypeFromWeightsType(Optional weightsType) +{ + if (!weightsType) + { + return weightsType; } - boost::optional GetBiasTypeFromWeightsType(boost::optional weightsType) + switch(weightsType.value()) { - if (weightsType == boost::none) - { + case DataType::Float16: + case DataType::Float32: return weightsType; - } - - switch(weightsType.get()) - { - case DataType::Float16: - case DataType::Float32: - return weightsType; - case DataType::QuantisedAsymm8: - return DataType::Signed32; - default: - BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type."); - } - return boost::none; + case DataType::QuantisedAsymm8: + return DataType::Signed32; + default: + BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type."); } + return EmptyOptional(); } +} // anonymous namespace + bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, const IConnectableLayer& connectableLayer, - boost::optional dataType, + Optional dataType, std::string& outReasonIfUnsupported) { Optional reason = outReasonIfUnsupported; @@ -589,7 +587,7 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, } bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer, - boost::optional dataType, + Optional dataType, std::string& outReasonIfUnsupported) { auto layer = boost::polymorphic_downcast(&connectableLayer); -- cgit v1.2.1