diff options
Diffstat (limited to 'src')
24 files changed, 333 insertions, 2 deletions
diff --git a/src/armnn/InternalTypes.cpp b/src/armnn/InternalTypes.cpp index fe1542b162..93a4f94378 100644 --- a/src/armnn/InternalTypes.cpp +++ b/src/armnn/InternalTypes.cpp @@ -39,6 +39,7 @@ char const* GetLayerTypeAsCString(LayerType type) case LayerType::Maximum: return "Maximum"; case LayerType::Mean: return "Mean"; case LayerType::MemCopy: return "MemCopy"; + case LayerType::Merge: return "Merge"; case LayerType::Merger: return "Merger"; case LayerType::Minimum: return "Minimum"; case LayerType::Multiplication: return "Multiplication"; diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp index 1972e9c1b5..7c7c601d95 100644 --- a/src/armnn/InternalTypes.hpp +++ b/src/armnn/InternalTypes.hpp @@ -39,6 +39,7 @@ enum class LayerType Maximum, Mean, MemCopy, + Merge, Merger, Minimum, Multiplication, diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp index 030973306f..bc6eec891b 100644 --- a/src/armnn/LayerSupport.cpp +++ b/src/armnn/LayerSupport.cpp @@ -355,6 +355,16 @@ bool IsMemCopySupported(const BackendId &backend, FORWARD_LAYER_SUPPORT_FUNC(backend, IsMemCopySupported, input, output); } +bool IsMergeSupported(const BackendId& backend, + const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + char* reasonIfUnsupported, + size_t reasonIfUnsupportedMaxLength) +{ + FORWARD_LAYER_SUPPORT_FUNC(backend, IsMergeSupported, input0, input1, output); +} + bool IsMergerSupported(const BackendId& backend, std::vector<const TensorInfo*> inputs, const TensorInfo& output, diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp index 9d87aeeee3..0bd68e04af 100644 --- a/src/armnn/LayersFwd.hpp +++ b/src/armnn/LayersFwd.hpp @@ -31,6 +31,7 @@ #include "layers/MaximumLayer.hpp" #include "layers/MeanLayer.hpp" #include "layers/MemCopyLayer.hpp" +#include "layers/MergeLayer.hpp" #include "layers/MergerLayer.hpp" #include "layers/MinimumLayer.hpp" #include "layers/MultiplicationLayer.hpp" @@ -102,6 +103,7 @@ 
DECLARE_LAYER(Lstm) DECLARE_LAYER(Maximum) DECLARE_LAYER(Mean) DECLARE_LAYER(MemCopy) +DECLARE_LAYER(Merge) DECLARE_LAYER(Merger) DECLARE_LAYER(Minimum) DECLARE_LAYER(Multiplication) diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 6dbd4611df..73db2e88d7 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -966,6 +966,11 @@ IConnectableLayer* Network::AddGatherLayer(const char* name) return m_Graph->AddLayer<GatherLayer>(name); } +IConnectableLayer* Network::AddMergeLayer(const char* name) +{ + return m_Graph->AddLayer<MergeLayer>(name); +} + void Network::Accept(ILayerVisitor& visitor) const { for (auto layer : GetGraph()) diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp index 782531acde..bb7b9eb6f4 100644 --- a/src/armnn/Network.hpp +++ b/src/armnn/Network.hpp @@ -174,6 +174,8 @@ public: IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override; + IConnectableLayer* AddMergeLayer(const char* name = nullptr) override; + void Accept(ILayerVisitor& visitor) const override; private: diff --git a/src/armnn/layers/MergeLayer.cpp b/src/armnn/layers/MergeLayer.cpp new file mode 100644 index 0000000000..1d4dc49379 --- /dev/null +++ b/src/armnn/layers/MergeLayer.cpp @@ -0,0 +1,65 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. 
+// SPDX-License-Identifier: MIT +// +#include "MergeLayer.hpp" + +#include "LayerCloneBase.hpp" + +#include <backendsCommon/WorkloadData.hpp> +#include <backendsCommon/WorkloadFactory.hpp> + +namespace armnn +{ + +MergeLayer::MergeLayer(const char* name) + : Layer(2, 1, LayerType::Merge, name) +{} + +std::unique_ptr<IWorkload> MergeLayer::CreateWorkload(const Graph& graph, + const IWorkloadFactory& factory) const +{ + return nullptr; +} + +MergeLayer* MergeLayer::Clone(Graph& graph) const +{ + return CloneBase<MergeLayer>(graph, GetName()); +} + +void MergeLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(2, CHECK_LOCATION()); + + std::vector<TensorShape> inferredShapes = InferOutputShapes({ + GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), + GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(), + }); + + BOOST_ASSERT(inferredShapes.size() == 1); + + ConditionalThrowIfNotEqual<LayerValidationException>( + "MergeLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", + GetOutputSlot(0).GetTensorInfo().GetShape(), + inferredShapes[0]); +} + +std::vector<TensorShape> MergeLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const +{ + BOOST_ASSERT(inputShapes.size() == 2); + + ConditionalThrowIfNotEqual<LayerValidationException>( + "MergeLayer: TensorShapes set on inputs do not match", + inputShapes[0], + inputShapes[1] + ); + + return {inputShapes[0]}; +} + +void MergeLayer::Accept(ILayerVisitor& visitor) const +{ + visitor.VisitMergeLayer(this, GetName()); +} + +} // namespace armnn diff --git a/src/armnn/layers/MergeLayer.hpp b/src/armnn/layers/MergeLayer.hpp new file mode 100644 index 0000000000..66664ca952 --- /dev/null +++ b/src/armnn/layers/MergeLayer.hpp @@ -0,0 +1,47 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "Layer.hpp" + +namespace armnn +{ + +/// This layer represents a merge operation. 
+class MergeLayer : public Layer +{ +public: + /// Makes a workload for the Merge type. + /// @param [in] graph The graph where this layer can be found. + /// @param [in] factory The workload factory which will create the workload. + /// @return A pointer to the created workload, or nullptr if not created. + virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph, + const IWorkloadFactory& factory) const override; + + /// Creates a dynamically-allocated copy of this layer. + /// @param [in] graph The graph into which this layer is being cloned. + MergeLayer* Clone(Graph& graph) const override; + + /// Check if the input tensor shape(s) + /// will lead to a valid configuration of @ref MergeLayer. + void ValidateTensorShapesFromInputs() override; + + /// Infers the output shapes from given input shapes. + /// @param [in] inputShapes The input shapes layer has. + /// @return A vector to the inferred output shape. + std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override; + + void Accept(ILayerVisitor& visitor) const override; + +protected: + /// Constructor to create a MergeLayer. + /// @param [in] name Optional name for the layer. 
+ MergeLayer(const char* name); + + /// Default destructor + ~MergeLayer() = default; +}; + +} // namespace armnn diff --git a/src/armnn/test/NetworkTests.cpp b/src/armnn/test/NetworkTests.cpp index 4de09a2804..dd8eb7773f 100644 --- a/src/armnn/test/NetworkTests.cpp +++ b/src/armnn/test/NetworkTests.cpp @@ -417,4 +417,56 @@ BOOST_AUTO_TEST_CASE(Network_AddQuantize) } +BOOST_AUTO_TEST_CASE(Network_AddMerge) +{ + struct Test : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> + { + void VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name) override + { + m_Visited = true; + + BOOST_TEST(layer); + + std::string expectedName = std::string("merge"); + BOOST_TEST(std::string(layer->GetName()) == expectedName); + BOOST_TEST(std::string(name) == expectedName); + + BOOST_TEST(layer->GetNumInputSlots() == 2); + BOOST_TEST(layer->GetNumOutputSlots() == 1); + + const armnn::TensorInfo& infoIn0 = layer->GetInputSlot(0).GetConnection()->GetTensorInfo(); + BOOST_TEST((infoIn0.GetDataType() == armnn::DataType::Float32)); + + const armnn::TensorInfo& infoIn1 = layer->GetInputSlot(1).GetConnection()->GetTensorInfo(); + BOOST_TEST((infoIn1.GetDataType() == armnn::DataType::Float32)); + + const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo(); + BOOST_TEST((infoOut.GetDataType() == armnn::DataType::Float32)); + } + + bool m_Visited = false; + }; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + + armnn::IConnectableLayer* input0 = network->AddInputLayer(0); + armnn::IConnectableLayer* input1 = network->AddInputLayer(1); + armnn::IConnectableLayer* merge = network->AddMergeLayer("merge"); + armnn::IConnectableLayer* output = network->AddOutputLayer(0); + + input0->GetOutputSlot(0).Connect(merge->GetInputSlot(0)); + input1->GetOutputSlot(0).Connect(merge->GetInputSlot(1)); + merge->GetOutputSlot(0).Connect(output->GetInputSlot(0)); + + const armnn::TensorInfo info({3,1}, armnn::DataType::Float32); + 
input0->GetOutputSlot(0).SetTensorInfo(info); + input1->GetOutputSlot(0).SetTensorInfo(info); + merge->GetOutputSlot(0).SetTensorInfo(info); + + Test testMerge; + network->Accept(testMerge); + + BOOST_TEST(testMerge.m_Visited == true); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 943c6a7fed..09cdd7cad3 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -206,6 +206,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer) m_ParserFunctions[Layer_MaximumLayer] = &Deserializer::ParseMaximum; m_ParserFunctions[Layer_MeanLayer] = &Deserializer::ParseMean; m_ParserFunctions[Layer_MinimumLayer] = &Deserializer::ParseMinimum; + m_ParserFunctions[Layer_MergeLayer] = &Deserializer::ParseMerge; m_ParserFunctions[Layer_MergerLayer] = &Deserializer::ParseMerger; m_ParserFunctions[Layer_MultiplicationLayer] = &Deserializer::ParseMultiplication; m_ParserFunctions[Layer_NormalizationLayer] = &Deserializer::ParseNormalization; @@ -271,6 +272,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt return graphPtr->layers()->Get(layerIndex)->layer_as_MinimumLayer()->base(); case Layer::Layer_MaximumLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_MaximumLayer()->base(); + case Layer::Layer_MergeLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_MergeLayer()->base(); case Layer::Layer_MergerLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_MergerLayer()->base(); case Layer::Layer_MultiplicationLayer: @@ -2085,4 +2088,24 @@ void Deserializer::ParseDequantize(GraphPtr graph, unsigned int layerIndex) RegisterOutputSlots(graph, layerIndex, layer); } +void Deserializer::ParseMerge(GraphPtr graph, unsigned int layerIndex) +{ + CHECK_LAYERS(graph, 0, layerIndex); + + Deserializer::TensorRawPtrVector inputs = GetInputs(graph, layerIndex); + CHECK_VALID_SIZE(inputs.size(), 2); + + 
Deserializer::TensorRawPtrVector outputs = GetOutputs(graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + const std::string layerName = GetLayerName(graph, layerIndex); + IConnectableLayer* layer = m_Network->AddMergeLayer(layerName.c_str()); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + } // namespace armnnDeserializer diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp index f18c163035..df983d9086 100644 --- a/src/armnnDeserializer/Deserializer.hpp +++ b/src/armnnDeserializer/Deserializer.hpp @@ -97,6 +97,7 @@ private: void ParseMaximum(GraphPtr graph, unsigned int layerIndex); void ParseMean(GraphPtr graph, unsigned int layerIndex); void ParseMinimum(GraphPtr graph, unsigned int layerIndex); + void ParseMerge(GraphPtr graph, unsigned int layerIndex); void ParseMerger(GraphPtr graph, unsigned int layerIndex); void ParseMultiplication(GraphPtr graph, unsigned int layerIndex); void ParseNormalization(GraphPtr graph, unsigned int layerIndex); diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md index 77856cf389..4e5610c569 100644 --- a/src/armnnDeserializer/DeserializerSupport.md +++ b/src/armnnDeserializer/DeserializerSupport.md @@ -25,6 +25,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers: * Lstm * Maximum * Mean +* Merge * Merger * Minimum * Multiplication diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index 3aa644dbe5..8b275b6f17 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -118,7 +118,8 @@ enum LayerType : uint { DetectionPostProcess = 33, Lstm = 34, Quantize = 35, - Dequantize = 36 + Dequantize = 36, + Merge = 37 } // Base layer table to be used as part of 
other layers @@ -524,6 +525,10 @@ table DequantizeLayer { base:LayerBase; } +table MergeLayer { + base:LayerBase; +} + union Layer { ActivationLayer, AdditionLayer, @@ -561,7 +566,8 @@ union Layer { DetectionPostProcessLayer, LstmLayer, QuantizeLayer, - DequantizeLayer + DequantizeLayer, + MergeLayer } table AnyLayer { diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 7181f01e6b..fe30c3eee5 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -500,6 +500,14 @@ void SerializerVisitor::VisitMinimumLayer(const armnn::IConnectableLayer* layer, CreateAnyLayer(fbMinimumLayer.o, serializer::Layer::Layer_MinimumLayer); } +void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name) +{ + auto fbMergeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merge); + auto fbMergeLayer = serializer::CreateMergeLayer(m_flatBufferBuilder, fbMergeBaseLayer); + + CreateAnyLayer(fbMergeLayer.o, serializer::Layer::Layer_MergeLayer); +} + void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer, const armnn::OriginsDescriptor& mergerDescriptor, const char* name) diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index 5c3e48a695..775df83966 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -129,6 +129,9 @@ public: void VisitMaximumLayer(const armnn::IConnectableLayer* layer, const char* name = nullptr) override; + void VisitMergeLayer(const armnn::IConnectableLayer* layer, + const char* name = nullptr) override; + void VisitMergerLayer(const armnn::IConnectableLayer* layer, const armnn::OriginsDescriptor& mergerDescriptor, const char* name = nullptr) override; diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md index a3c5852bd2..a8335e1e68 100644 --- a/src/armnnSerializer/SerializerSupport.md +++ 
b/src/armnnSerializer/SerializerSupport.md @@ -25,6 +25,7 @@ The Arm NN SDK Serializer currently supports the following layers: * Lstm * Maximum * Mean +* Merge * Merger * Minimum * Multiplication diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 0979076476..a1ef9eef59 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1185,6 +1185,46 @@ BOOST_AUTO_TEST_CASE(SerializeMean) deserializedNetwork->Accept(verifier); } +BOOST_AUTO_TEST_CASE(SerializeMerge) +{ + class MergeLayerVerifier : public LayerVerifierBase + { + public: + MergeLayerVerifier(const std::string& layerName, + const std::vector<armnn::TensorInfo>& inputInfos, + const std::vector<armnn::TensorInfo>& outputInfos) + : LayerVerifierBase(layerName, inputInfos, outputInfos) {} + + void VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name) override + { + VerifyNameAndConnections(layer, name); + } + }; + + const std::string layerName("merge"); + const armnn::TensorInfo info({ 1, 2, 2, 3 }, armnn::DataType::Float32); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0); + armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1); + armnn::IConnectableLayer* const mergeLayer = network->AddMergeLayer(layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputLayer0->GetOutputSlot(0).Connect(mergeLayer->GetInputSlot(0)); + inputLayer1->GetOutputSlot(0).Connect(mergeLayer->GetInputSlot(1)); + mergeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer0->GetOutputSlot(0).SetTensorInfo(info); + inputLayer1->GetOutputSlot(0).SetTensorInfo(info); + mergeLayer->GetOutputSlot(0).SetTensorInfo(info); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + 
BOOST_CHECK(deserializedNetwork); + + MergeLayerVerifier verifier(layerName, {info, info}, {info}); + deserializedNetwork->Accept(verifier); +} + BOOST_AUTO_TEST_CASE(SerializeMerger) { class MergerLayerVerifier : public LayerVerifierBase diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index 04f822cea9..fc2d502fbd 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -253,6 +253,14 @@ bool LayerSupportBase::IsMemCopySupported(const armnn::TensorInfo& input, return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } +bool LayerSupportBase::IsMergeSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + Optional<std::string&> reasonIfUnsupported) const +{ + return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); +} + bool LayerSupportBase::IsMergerSupported(const std::vector<const TensorInfo*> inputs, const TensorInfo& output, const OriginsDescriptor& descriptor, diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index 7d64095667..7c38b67379 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -160,6 +160,11 @@ public: const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsMergeSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsMergerSupported(const std::vector<const TensorInfo*> inputs, const TensorInfo& output, const OriginsDescriptor& descriptor, diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 91b1c5790b..348c864863 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ 
b/src/backends/backendsCommon/WorkloadData.cpp @@ -1170,6 +1170,28 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const } } +void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateTwoInputs(workloadInfo, "MergeQueueDescriptor"); + ValidateSingleOutput(workloadInfo, "MergeQueueDescriptor"); + + ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], + workloadInfo.m_InputTensorInfos[1], + "MergeQueueDescriptor", + "input0", + "input1"); + + ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], + workloadInfo.m_OutputTensorInfos[0], + "MergeQueueDescriptor", + "input0", + "output"); + + const DataType dataType = workloadInfo.m_InputTensorInfos[0].GetDataType(); + ValidateTensorDataType(workloadInfo.m_InputTensorInfos[1], dataType, "MergeQueueDescriptor", "input1"); + ValidateTensorDataType(workloadInfo.m_OutputTensorInfos[0], dataType, "MergeQueueDescriptor", "output"); +} + void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { // This is internally generated so it should not need validation. 
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index 5640701d82..1bf735288d 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -421,4 +421,9 @@ struct DequantizeQueueDescriptor : QueueDescriptor void Validate(const WorkloadInfo& workloadInfo) const; }; +struct MergeQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + } //namespace armnn diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 6534a00343..4ea3ea9f9b 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -519,6 +519,18 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, reason); break; } + case LayerType::Merge: + { + const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo(); + const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo(); + + result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType), + OverrideDataType(input1, dataType), + OverrideDataType(output, dataType), + reason); + break; + } case LayerType::Merger: { auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer); @@ -915,6 +927,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDes return std::unique_ptr<IWorkload>(); } +std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::unique_ptr<IWorkload>(); +} + std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor, const WorkloadInfo& info) const { diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp index 
ed7303cf33..889bc9d595 100644 --- a/src/backends/backendsCommon/WorkloadFactory.hpp +++ b/src/backends/backendsCommon/WorkloadFactory.hpp @@ -121,6 +121,9 @@ public: virtual std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor, const WorkloadInfo& info) const; + virtual std::unique_ptr<IWorkload> CreateMerge(const MergeQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + virtual std::unique_ptr<IWorkload> CreateMerger(const MergerQueueDescriptor& descriptor, const WorkloadInfo& info) const; diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp index 26fb03f55d..0588607a82 100644 --- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp +++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp @@ -362,6 +362,8 @@ DECLARE_LAYER_POLICY_1_PARAM(Maximum) DECLARE_LAYER_POLICY_2_PARAM(Mean) +DECLARE_LAYER_POLICY_1_PARAM(Merge) + DECLARE_LAYER_POLICY_2_PARAM(Merger) DECLARE_LAYER_POLICY_1_PARAM(Minimum) |