aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2021-11-24 15:47:28 +0000
committerSadik Armagan <sadik.armagan@arm.com>2021-12-14 11:02:41 +0000
commita097d2a0ed8e30d5aaf6d29ec18d0c39201b7b67 (patch)
tree947e587bc42d07f52c55b155308b5ea5bd3ebacd /src/backends/backendsCommon/WorkloadFactory.cpp
parentbc14881a76699dd942e94265116da68a6466455e (diff)
downloadarmnn-a097d2a0ed8e30d5aaf6d29ec18d0c39201b7b67.tar.gz
IVGCVSW-6453 'Move the ArmNN Test Utils code to a physically separate directory'
* Created include/armnnTestUtils directory * Moved Arm NN test utils files into armnnTestUtils directory Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: I03ac54c645c41c52650c4c03b6a58fb1481fef5d
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp28
1 files changed, 27 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index ef2a34889e..93932a83a1 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -8,6 +8,7 @@
#include <armnn/Types.hpp>
#include <armnn/LayerSupport.hpp>
+#include <armnn/backends/IBackendInternal.hpp>
#include <armnn/backends/ILayerSupport.hpp>
#include <armnn/BackendHelper.hpp>
#include <armnn/BackendRegistry.hpp>
@@ -17,7 +18,7 @@
#include <backendsCommon/WorkloadFactory.hpp>
#include <backendsCommon/TensorHandle.hpp>
-#include <backendsCommon/test/WorkloadTestUtils.hpp>
+//#include <WorkloadTestUtils.hpp>
#include <sstream>
@@ -45,6 +46,31 @@ const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> typ
} // anonymous namespace
+inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
+{
+ if (!weightsType)
+ {
+ return weightsType;
+ }
+
+ switch(weightsType.value())
+ {
+ case armnn::DataType::BFloat16:
+ case armnn::DataType::Float16:
+ case armnn::DataType::Float32:
+ return weightsType;
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS8:
+ case armnn::DataType::QSymmS16:
+ return armnn::DataType::Signed32;
+ default:
+ ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
+ }
+ return armnn::EmptyOptional();
+}
+
+
bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
const IConnectableLayer& connectableLayer,
Optional<DataType> dataType,