aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp41
1 files changed, 38 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 1283f67660..51bc3e60cb 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -15,7 +15,6 @@
#include <armnn/utility/TransformIterator.hpp>
#include <armnn/backends/WorkloadFactory.hpp>
-#include <armnn/backends/TensorHandle.hpp>
#include <sstream>
@@ -91,7 +90,8 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
auto backendFactory = backendRegistry.GetFactory(backendId);
auto backendObject = backendFactory();
- auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
+ auto layerSupport = backendObject->GetLayerSupport(modelOptions);
+ auto layerSupportObject = LayerSupportHandle(layerSupport, backendId);
switch(layer.GetType())
{
@@ -109,6 +109,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
}
case LayerType::Addition:
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
@@ -117,6 +118,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
OverrideDataType(input1, dataType),
OverrideDataType(output, dataType),
reason);
+ ARMNN_NO_DEPRECATE_WARN_END
break;
}
case LayerType::ArgMinMax:
@@ -392,6 +394,24 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
reason);
break;
}
+ case LayerType::ElementwiseBinary:
+ {
+ auto cLayer = PolymorphicDowncast<const ElementwiseBinaryLayer*>(&layer);
+
+ const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ std::vector<TensorInfo> infos = { OverrideDataType(input0, dataType),
+ OverrideDataType(input1, dataType),
+ OverrideDataType(output, dataType) };
+ result = layerSupport->IsLayerSupported(LayerType::ElementwiseBinary,
+ infos,
+ cLayer->GetParameters(),
+ EmptyOptional(),
+ EmptyOptional(),
+ reason);
+ break;
+ }
case LayerType::ElementwiseUnary:
{
auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
@@ -740,6 +760,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
}
case LayerType::Maximum:
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
@@ -748,6 +769,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
OverrideDataType(input1, dataType),
OverrideDataType(output, dataType),
reason);
+ ARMNN_NO_DEPRECATE_WARN_END
break;
}
case LayerType::MemCopy:
@@ -814,6 +836,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
}
case LayerType::Multiplication:
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
@@ -822,6 +845,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
OverrideDataType(input1, dataType),
OverrideDataType(output, dataType),
reason);
+ ARMNN_NO_DEPRECATE_WARN_END
break;
}
case LayerType::Normalization:
@@ -1052,6 +1076,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
}
case LayerType::Division:
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
@@ -1060,6 +1085,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
OverrideDataType(input1, dataType),
OverrideDataType(output, dataType),
reason);
+ ARMNN_NO_DEPRECATE_WARN_END
break;
}
case LayerType::Rank:
@@ -1254,6 +1280,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
}
case LayerType::Subtraction:
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
@@ -1262,6 +1289,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
OverrideDataType(input1, dataType),
OverrideDataType(output, dataType),
reason);
+ ARMNN_NO_DEPRECATE_WARN_END
break;
}
case LayerType::Switch:
@@ -1291,6 +1319,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
}
case LayerType::Minimum:
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
@@ -1298,6 +1327,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
OverrideDataType(input1, dataType),
OverrideDataType(output, dataType),
reason);
+ ARMNN_NO_DEPRECATE_WARN_END
break;
}
case LayerType::Prelu:
@@ -1670,6 +1700,11 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateWorkload(LayerType type,
auto divisionQueueDescriptor = PolymorphicDowncast<const DivisionQueueDescriptor*>(&descriptor);
return CreateDivision(*divisionQueueDescriptor, info);
}
+ case LayerType::ElementwiseBinary:
+ {
+ auto queueDescriptor = PolymorphicDowncast<const ElementwiseBinaryQueueDescriptor*>(&descriptor);
+ return CreateWorkload(LayerType::ElementwiseBinary, *queueDescriptor, info);
+ }
case LayerType::ElementwiseUnary:
{
auto elementwiseUnaryQueueDescriptor