aboutsummaryrefslogtreecommitdiff
path: root/src/backends/WorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/WorkloadFactory.cpp')
-rw-r--r--src/backends/WorkloadFactory.cpp54
1 files changed, 26 insertions, 28 deletions
diff --git a/src/backends/WorkloadFactory.cpp b/src/backends/WorkloadFactory.cpp
index e7dec49db4..fea383f030 100644
--- a/src/backends/WorkloadFactory.cpp
+++ b/src/backends/WorkloadFactory.cpp
@@ -5,10 +5,6 @@
#include <backends/WorkloadFactory.hpp>
#include <backends/LayerSupportRegistry.hpp>
-#include <backends/reference/RefWorkloadFactory.hpp>
-#include <backends/neon/NeonWorkloadFactory.hpp>
-#include <backends/cl/ClWorkloadFactory.hpp>
-
#include <armnn/Types.hpp>
#include <armnn/LayerSupport.hpp>
#include <Layer.hpp>
@@ -24,40 +20,42 @@ namespace armnn
namespace
{
- const TensorInfo OverrideDataType(const TensorInfo& info, boost::optional<DataType> type)
+
+const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
+{
+ if (!type)
{
- if (type == boost::none)
- {
- return info;
- }
+ return info;
+ }
+
+ return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
+}
- return TensorInfo(info.GetShape(), type.get(), info.GetQuantizationScale(), info.GetQuantizationOffset());
+Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
+{
+ if (!weightsType)
+ {
+ return weightsType;
}
- boost::optional<DataType> GetBiasTypeFromWeightsType(boost::optional<DataType> weightsType)
+ switch(weightsType.value())
{
- if (weightsType == boost::none)
- {
+ case DataType::Float16:
+ case DataType::Float32:
return weightsType;
- }
-
- switch(weightsType.get())
- {
- case DataType::Float16:
- case DataType::Float32:
- return weightsType;
- case DataType::QuantisedAsymm8:
- return DataType::Signed32;
- default:
- BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
- }
- return boost::none;
+ case DataType::QuantisedAsymm8:
+ return DataType::Signed32;
+ default:
+ BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
}
+ return EmptyOptional();
}
+} // anonymous namespace
+
bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
const IConnectableLayer& connectableLayer,
- boost::optional<DataType> dataType,
+ Optional<DataType> dataType,
std::string& outReasonIfUnsupported)
{
Optional<std::string&> reason = outReasonIfUnsupported;
@@ -589,7 +587,7 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
}
bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
- boost::optional<DataType> dataType,
+ Optional<DataType> dataType,
std::string& outReasonIfUnsupported)
{
auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);