aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp28
1 files changed, 27 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index ef2a34889e..93932a83a1 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -8,6 +8,7 @@
#include <armnn/Types.hpp>
#include <armnn/LayerSupport.hpp>
+#include <armnn/backends/IBackendInternal.hpp>
#include <armnn/backends/ILayerSupport.hpp>
#include <armnn/BackendHelper.hpp>
#include <armnn/BackendRegistry.hpp>
@@ -17,7 +18,7 @@
#include <backendsCommon/WorkloadFactory.hpp>
#include <backendsCommon/TensorHandle.hpp>
-#include <backendsCommon/test/WorkloadTestUtils.hpp>
+//#include <WorkloadTestUtils.hpp>
#include <sstream>
@@ -45,6 +46,31 @@ const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> typ
} // anonymous namespace
+inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
+{
+ if (!weightsType)
+ {
+ return weightsType;
+ }
+
+ switch(weightsType.value())
+ {
+ case armnn::DataType::BFloat16:
+ case armnn::DataType::Float16:
+ case armnn::DataType::Float32:
+ return weightsType;
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS8:
+ case armnn::DataType::QSymmS16:
+ return armnn::DataType::Signed32;
+ default:
+ ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
+ }
+ return armnn::EmptyOptional();
+}
+
+
bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
const IConnectableLayer& connectableLayer,
Optional<DataType> dataType,