Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 41
1 file changed, 38 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 1283f67660..51bc3e60cb 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
@@ -15,7 +15,6 @@
 #include <armnn/utility/TransformIterator.hpp>
 
 #include <armnn/backends/WorkloadFactory.hpp>
-#include <armnn/backends/TensorHandle.hpp>
 
 #include <sstream>
 
@@ -91,7 +90,8 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
 
     auto backendFactory = backendRegistry.GetFactory(backendId);
     auto backendObject = backendFactory();
-    auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
+    auto layerSupport = backendObject->GetLayerSupport(modelOptions);
+    auto layerSupportObject = LayerSupportHandle(layerSupport, backendId);
 
     switch(layer.GetType())
     {
@@ -109,6 +109,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
         }
         case LayerType::Addition:
         {
+            ARMNN_NO_DEPRECATE_WARN_BEGIN
             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
@@ -117,6 +118,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
                                                          OverrideDataType(input1, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
+            ARMNN_NO_DEPRECATE_WARN_END
             break;
         }
         case LayerType::ArgMinMax:
@@ -392,6 +394,24 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
                                                      reason);
             break;
         }
+        case LayerType::ElementwiseBinary:
+        {
+            auto cLayer = PolymorphicDowncast<const ElementwiseBinaryLayer*>(&layer);
+
+            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
+            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+            std::vector<TensorInfo> infos = { OverrideDataType(input0, dataType),
+                                              OverrideDataType(input1, dataType),
+                                              OverrideDataType(output, dataType) };
+            result = layerSupport->IsLayerSupported(LayerType::ElementwiseBinary,
+                                                    infos,
+                                                    cLayer->GetParameters(),
+                                                    EmptyOptional(),
+                                                    EmptyOptional(),
+                                                    reason);
+            break;
+        }
         case LayerType::ElementwiseUnary:
         {
             auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
@@ -740,6 +760,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
         }
         case LayerType::Maximum:
         {
+            ARMNN_NO_DEPRECATE_WARN_BEGIN
             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
@@ -748,6 +769,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
                                                         OverrideDataType(input1, dataType),
                                                         OverrideDataType(output, dataType),
                                                         reason);
+            ARMNN_NO_DEPRECATE_WARN_END
             break;
         }
         case LayerType::MemCopy:
@@ -814,6 +836,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
         }
         case LayerType::Multiplication:
         {
+            ARMNN_NO_DEPRECATE_WARN_BEGIN
             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
@@ -822,6 +845,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
                                                                OverrideDataType(input1, dataType),
                                                                OverrideDataType(output, dataType),
                                                                reason);
+            ARMNN_NO_DEPRECATE_WARN_END
             break;
         }
         case LayerType::Normalization:
@@ -1052,6 +1076,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
         }
         case LayerType::Division:
         {
+            ARMNN_NO_DEPRECATE_WARN_BEGIN
             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
@@ -1060,6 +1085,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
                                                          OverrideDataType(input1, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
+            ARMNN_NO_DEPRECATE_WARN_END
             break;
         }
         case LayerType::Rank:
@@ -1254,6 +1280,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
         }
         case LayerType::Subtraction:
        {
+            ARMNN_NO_DEPRECATE_WARN_BEGIN
             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
@@ -1262,6 +1289,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
                                                             OverrideDataType(input1, dataType),
                                                             OverrideDataType(output, dataType),
                                                             reason);
+            ARMNN_NO_DEPRECATE_WARN_END
             break;
         }
         case LayerType::Switch:
@@ -1291,6 +1319,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
         }
         case LayerType::Minimum:
         {
+            ARMNN_NO_DEPRECATE_WARN_BEGIN
             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
@@ -1298,6 +1327,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
                                                         OverrideDataType(input1, dataType),
                                                         OverrideDataType(output, dataType),
                                                         reason);
+            ARMNN_NO_DEPRECATE_WARN_END
             break;
         }
         case LayerType::Prelu:
@@ -1670,6 +1700,11 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateWorkload(LayerType type,
             auto divisionQueueDescriptor = PolymorphicDowncast<const DivisionQueueDescriptor*>(&descriptor);
             return CreateDivision(*divisionQueueDescriptor, info);
         }
+        case LayerType::ElementwiseBinary:
+        {
+            auto queueDescriptor = PolymorphicDowncast<const ElementwiseBinaryQueueDescriptor*>(&descriptor);
+            return CreateWorkload(LayerType::ElementwiseBinary, *queueDescriptor, info);
+        }
         case LayerType::ElementwiseUnary:
         {
             auto elementwiseUnaryQueueDescriptor
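
Context on the main change: the new ElementwiseBinary case queries support through the generic ILayerSupport::IsLayerSupported entry point (a LayerType, a vector of TensorInfos, and the layer's descriptor) instead of a dedicated per-operation call such as IsAdditionSupported. Below is a minimal sketch of the same kind of check made from application code, assuming Arm NN 23.02-era headers; the function name and the choice of BinaryOperation::Add are illustrative and not part of this commit:

    // Sketch only: mirrors the support query in the diff from outside the factory.
    #include <armnn/BackendRegistry.hpp>
    #include <armnn/Descriptors.hpp>
    #include <armnn/Optional.hpp>
    #include <armnn/Tensor.hpp>
    #include <armnn/Types.hpp>
    #include <armnn/backends/IBackendInternal.hpp>
    #include <string>
    #include <vector>

    bool IsBinaryAddSupported(const armnn::BackendId& backendId,
                              const armnn::TensorInfo& input0,
                              const armnn::TensorInfo& input1,
                              const armnn::TensorInfo& output)
    {
        // Same lookup the diff performs: registry -> backend -> layer support.
        auto backendObject = armnn::BackendRegistryInstance().GetFactory(backendId)();
        auto layerSupport  = backendObject->GetLayerSupport({}); // empty ModelOptions

        // The descriptor says which binary operation is meant (Add, Sub, ...).
        armnn::ElementwiseBinaryDescriptor descriptor(armnn::BinaryOperation::Add);

        std::vector<armnn::TensorInfo> infos = { input0, input1, output };
        std::string reason;
        return layerSupport->IsLayerSupported(armnn::LayerType::ElementwiseBinary,
                                              infos,
                                              descriptor,
                                              armnn::EmptyOptional(),
                                              armnn::EmptyOptional(),
                                              armnn::Optional<std::string&>(reason));
    }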
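The ARMNN_NO_DEPRECATE_WARN_BEGIN/END pairs wrapped around the Addition, Subtraction, Multiplication, Division, Maximum, and Minimum cases keep the file building cleanly (e.g. under -Werror) while those cases still call the now-deprecated per-operation IsXxxSupported functions. A sketch of the push/ignore/pop pattern such guards typically expand to; the real definitions live in Arm NN's Deprecated.hpp, and this standalone version is illustrative:

    // Illustrative stand-ins for ARMNN_NO_DEPRECATE_WARN_BEGIN/END (GCC/Clang).
    #define NO_DEPRECATE_WARN_BEGIN \
        _Pragma("GCC diagnostic push") \
        _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"")
    #define NO_DEPRECATE_WARN_END \
        _Pragma("GCC diagnostic pop")

    [[deprecated("superseded by the unified ElementwiseBinary support check")]]
    bool IsAdditionSupported() { return true; }

    bool CheckLegacyAddition()
    {
        NO_DEPRECATE_WARN_BEGIN
        bool supported = IsAdditionSupported(); // no deprecation warning emitted here
        NO_DEPRECATE_WARN_END
        return supported;
    }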
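Finally, the CreateWorkload switch gains a forwarding case: rather than adding a CreateElementwiseBinary virtual, the per-layer dispatcher hands the ElementwiseBinaryQueueDescriptor straight to the unified CreateWorkload(LayerType, QueueDescriptor, WorkloadInfo) overload. The shape of that dispatch, reduced to a self-contained sketch with simplified stand-in types (not the real Arm NN class hierarchy):

    #include <memory>

    // Simplified stand-ins; the real types live in armnn's backends headers.
    struct WorkloadInfo {};
    struct IWorkload { virtual ~IWorkload() = default; };
    struct QueueDescriptor { virtual ~QueueDescriptor() = default; };
    struct ElementwiseBinaryQueueDescriptor : QueueDescriptor {};
    enum class LayerType { ElementwiseBinary /*, ... */ };

    struct WorkloadFactory
    {
        virtual ~WorkloadFactory() = default;

        // Unified entry point: one function covers every layer type.
        virtual std::unique_ptr<IWorkload> CreateWorkload(LayerType /*type*/,
                                                          const QueueDescriptor& /*desc*/,
                                                          const WorkloadInfo& /*info*/)
        {
            return nullptr; // backends override this with real workload creation
        }

        // Legacy per-layer dispatcher, now just forwarding (as in the diff).
        std::unique_ptr<IWorkload> CreateLegacy(LayerType type,
                                                const QueueDescriptor& desc,
                                                const WorkloadInfo& info)
        {
            switch (type)
            {
                case LayerType::ElementwiseBinary:
                {
                    auto& qd = static_cast<const ElementwiseBinaryQueueDescriptor&>(desc);
                    return CreateWorkload(LayerType::ElementwiseBinary, qd, info);
                }
            }
            return nullptr;
        }
    };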