diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index ef2a34889e..93932a83a1 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -8,6 +8,7 @@ #include <armnn/Types.hpp> #include <armnn/LayerSupport.hpp> +#include <armnn/backends/IBackendInternal.hpp> #include <armnn/backends/ILayerSupport.hpp> #include <armnn/BackendHelper.hpp> #include <armnn/BackendRegistry.hpp> @@ -17,7 +18,7 @@ #include <backendsCommon/WorkloadFactory.hpp> #include <backendsCommon/TensorHandle.hpp> -#include <backendsCommon/test/WorkloadTestUtils.hpp> +//#include <WorkloadTestUtils.hpp> #include <sstream> @@ -45,6 +46,31 @@ const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> typ } // anonymous namespace +inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType) +{ + if (!weightsType) + { + return weightsType; + } + + switch(weightsType.value()) + { + case armnn::DataType::BFloat16: + case armnn::DataType::Float16: + case armnn::DataType::Float32: + return weightsType; + case armnn::DataType::QAsymmS8: + case armnn::DataType::QAsymmU8: + case armnn::DataType::QSymmS8: + case armnn::DataType::QSymmS16: + return armnn::DataType::Signed32; + default: + ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type."); + } + return armnn::EmptyOptional(); +} + + bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId, const IConnectableLayer& connectableLayer, Optional<DataType> dataType, |