diff options
author | David Beck <david.beck@arm.com> | 2018-10-23 13:35:58 +0100 |
---|---|---|
committer | Matthew Bentham <matthew.bentham@arm.com> | 2018-10-25 09:49:58 +0100 |
commit | 29c75de868ac3a86a70b25f8da0d0c7e47d40803 (patch) | |
tree | db8dd31d26622ca252f6e2d2c86c7e20a0829e9e /src/backends/WorkloadFactory.cpp | |
parent | 5cc8e56b4ca8d58dc11973c49c10a02a2f13580c (diff) | |
download | armnn-29c75de868ac3a86a70b25f8da0d0c7e47d40803.tar.gz |
IVGCVSW-2067 : dynamically create workload factories based on the backends in the network
Change-Id: Ide594db8c79ff67642721d8bad47624b88621fbd
Diffstat (limited to 'src/backends/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/WorkloadFactory.cpp | 54 |
1 files changed, 26 insertions, 28 deletions
diff --git a/src/backends/WorkloadFactory.cpp b/src/backends/WorkloadFactory.cpp index e7dec49db4..fea383f030 100644 --- a/src/backends/WorkloadFactory.cpp +++ b/src/backends/WorkloadFactory.cpp @@ -5,10 +5,6 @@ #include <backends/WorkloadFactory.hpp> #include <backends/LayerSupportRegistry.hpp> -#include <backends/reference/RefWorkloadFactory.hpp> -#include <backends/neon/NeonWorkloadFactory.hpp> -#include <backends/cl/ClWorkloadFactory.hpp> - #include <armnn/Types.hpp> #include <armnn/LayerSupport.hpp> #include <Layer.hpp> @@ -24,40 +20,42 @@ namespace armnn namespace { - const TensorInfo OverrideDataType(const TensorInfo& info, boost::optional<DataType> type) + +const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type) +{ + if (!type) { - if (type == boost::none) - { - return info; - } + return info; + } + + return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset()); +} - return TensorInfo(info.GetShape(), type.get(), info.GetQuantizationScale(), info.GetQuantizationOffset()); +Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType) +{ + if (!weightsType) + { + return weightsType; } - boost::optional<DataType> GetBiasTypeFromWeightsType(boost::optional<DataType> weightsType) + switch(weightsType.value()) { - if (weightsType == boost::none) - { + case DataType::Float16: + case DataType::Float32: return weightsType; - } - - switch(weightsType.get()) - { - case DataType::Float16: - case DataType::Float32: - return weightsType; - case DataType::QuantisedAsymm8: - return DataType::Signed32; - default: - BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type."); - } - return boost::none; + case DataType::QuantisedAsymm8: + return DataType::Signed32; + default: + BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type."); } + return EmptyOptional(); } +} // anonymous namespace + bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, const IConnectableLayer& connectableLayer, - boost::optional<DataType> dataType, + Optional<DataType> dataType, std::string& outReasonIfUnsupported) { Optional<std::string&> reason = outReasonIfUnsupported; @@ -589,7 +587,7 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, } bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer, - boost::optional<DataType> dataType, + Optional<DataType> dataType, std::string& outReasonIfUnsupported) { auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer); |