From 526647333571169076f5e72c9fb18c71025bf7c0 Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Mon, 29 Jun 2020 16:27:03 +0100 Subject: IVGCVSW-4903 Connect axis parameter in Gather from android to ACL. !android-nn-driver:3302 Signed-off-by: Teresa Charlin Change-Id: Ifbc49acb5272f8a36719bb68676e44817190537d --- src/armnn/LayerSupport.cpp | 17 +- src/armnn/Network.cpp | 11 +- src/armnn/Network.hpp | 6 +- src/armnn/layers/GatherLayer.cpp | 27 +- src/armnn/layers/GatherLayer.hpp | 11 +- src/armnn/test/OptimizerTests.cpp | 5 +- .../test/TestNameAndDescriptorLayerVisitor.cpp | 9 +- .../test/TestNameAndDescriptorLayerVisitor.hpp | 3 +- src/armnn/test/TestNameOnlyLayerVisitor.cpp | 3 +- src/armnn/test/TestNameOnlyLayerVisitor.hpp | 3 +- src/armnnDeserializer/Deserializer.cpp | 5 +- src/armnnDeserializer/test/DeserializeGather.cpp | 18 +- src/armnnSerializer/ArmnnSchema.fbs | 7 +- src/armnnSerializer/Serializer.cpp | 17 +- src/armnnSerializer/Serializer.hpp | 7 +- src/armnnSerializer/test/SerializerTests.cpp | 21 +- src/armnnTfParser/TfParser.cpp | 6 +- src/armnnTfParser/test/Gather.cpp | 26 +- src/backends/backendsCommon/LayerSupportBase.cpp | 445 +++++++++++---------- src/backends/backendsCommon/LayerSupportBase.hpp | 9 +- src/backends/backendsCommon/WorkloadData.hpp | 2 +- src/backends/backendsCommon/WorkloadFactory.cpp | 5 +- .../backendsCommon/test/GatherEndToEndTestImpl.hpp | 5 +- .../test/IsLayerSupportedTestImpl.hpp | 4 +- src/backends/cl/ClLayerSupport.cpp | 4 +- src/backends/cl/ClLayerSupport.hpp | 3 +- src/backends/cl/workloads/ClGatherWorkload.cpp | 9 +- src/backends/cl/workloads/ClGatherWorkload.hpp | 3 +- src/backends/neon/NeonLayerSupport.cpp | 6 +- src/backends/neon/NeonLayerSupport.hpp | 1 + src/backends/neon/workloads/NeonGatherWorkload.cpp | 9 +- src/backends/neon/workloads/NeonGatherWorkload.hpp | 5 +- src/backends/reference/RefLayerSupport.cpp | 8 +- src/backends/reference/RefLayerSupport.hpp | 4 +- src/backends/reference/workloads/Gather.cpp | 7 +- src/backends/reference/workloads/Gather.hpp | 5 +- .../reference/workloads/RefGatherWorkload.cpp | 4 +- 37 files changed, 438 insertions(+), 302 deletions(-) (limited to 'src') diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp index fe5b542867..197e1afe18 100644 --- a/src/armnn/LayerSupport.cpp +++ b/src/armnn/LayerSupport.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -305,6 +305,7 @@ bool IsFullyConnectedSupported(const BackendId& backend, FORWARD_LAYER_SUPPORT_FUNC(backend, IsFullyConnectedSupported, input, output, weights, biases, descriptor); } +ARMNN_DEPRECATED_MSG("Use IsGatherSupported with descriptor instead") bool IsGatherSupported(const BackendId& backend, const TensorInfo& input0, const TensorInfo& input1, @@ -312,7 +313,19 @@ bool IsGatherSupported(const BackendId& backend, char* reasonIfUnsupported, size_t reasonIfUnsupportedMaxLength) { - FORWARD_LAYER_SUPPORT_FUNC(backend, IsGatherSupported, input0, input1, output); + const GatherDescriptor descriptor{}; + FORWARD_LAYER_SUPPORT_FUNC(backend, IsGatherSupported, input0, input1, output, descriptor); +} + +bool IsGatherSupported(const BackendId& backend, + const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const GatherDescriptor& descriptor, + char* reasonIfUnsupported, + size_t reasonIfUnsupportedMaxLength) +{ + FORWARD_LAYER_SUPPORT_FUNC(backend, IsGatherSupported, input0, input1, output, descriptor); } bool IsGreaterSupported(const BackendId& backend, diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 180c00b0c3..6c7314feb2 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -1724,7 +1724,14 @@ IConnectableLayer* Network::AddRsqrtLayer(const char * name) IConnectableLayer* Network::AddGatherLayer(const char* name) { - return m_Graph->AddLayer(name); + GatherDescriptor gatherDescriptor{}; + return AddGatherLayer(gatherDescriptor, name); +} + +IConnectableLayer* Network::AddGatherLayer(const GatherDescriptor& gatherDescriptor, + const char* name) +{ + return m_Graph->AddLayer(gatherDescriptor, name); } IConnectableLayer* Network::AddMergeLayer(const char* name) diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp index cac2b3a0e6..53bf3115f1 100644 --- a/src/armnn/Network.hpp +++ b/src/armnn/Network.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -116,8 +116,12 @@ public: const ConstTensor& biases, const char* name = nullptr) override; + ARMNN_DEPRECATED_MSG("This AddGatherLayer overload is deprecated") IConnectableLayer* AddGatherLayer(const char* name = nullptr) override; + IConnectableLayer* AddGatherLayer(const GatherDescriptor& gatherDescriptor, + const char* name = nullptr) override; + IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor, const char* name = nullptr) override; diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp index a99913073f..3e85d25dac 100644 --- a/src/armnn/layers/GatherLayer.cpp +++ b/src/armnn/layers/GatherLayer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -13,8 +13,8 @@ namespace armnn { -GatherLayer::GatherLayer(const char* name) - : Layer(2, 1, LayerType::Gather, name) +GatherLayer::GatherLayer(const GatherDescriptor& param, const char* name) + : LayerWithParameters(2, 1, LayerType::Gather, param, name) { } @@ -26,7 +26,7 @@ std::unique_ptr GatherLayer::CreateWorkload(const armnn::IWorkloadFac GatherLayer* GatherLayer::Clone(Graph& graph) const { - return CloneBase(graph, GetName()); + return CloneBase(graph, m_Param, GetName()); } void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInferenceMethod) @@ -44,11 +44,22 @@ void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInfer std::vector dimSizes; - for (unsigned int i = 0; i < indicesDim; ++i) + unsigned int axis = static_cast(m_Param.m_Axis); + if (m_Param.m_Axis < 0) { - dimSizes.push_back(indices.GetShape()[i]); + int32_t axis_aux = static_cast(paramsDim) + m_Param.m_Axis; + axis = static_cast (axis_aux); } - for (unsigned int i = 1; i < paramsDim; ++i) + + for (unsigned int i = 0; i < axis; ++i) + { + dimSizes.push_back(params.GetShape()[i]); + } + for (unsigned int i = axis; i < indicesDim + axis; ++i) + { + dimSizes.push_back(indices.GetShape()[i - axis]); + } + for (unsigned int i = 1 + axis; i < paramsDim; ++i) { dimSizes.push_back(params.GetShape()[i]); } @@ -63,7 +74,7 @@ void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInfer void GatherLayer::Accept(ILayerVisitor& visitor) const { - visitor.VisitGatherLayer(this, GetName()); + visitor.VisitGatherLayer(this, GetParameters(), GetName()); } } // namespace armnn diff --git a/src/armnn/layers/GatherLayer.hpp b/src/armnn/layers/GatherLayer.hpp index 598ca44dc4..d8737adbee 100644 --- a/src/armnn/layers/GatherLayer.hpp +++ b/src/armnn/layers/GatherLayer.hpp @@ -1,17 +1,17 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once -#include "Layer.hpp" +#include "LayerWithParameters.hpp" namespace armnn { /// This layer represents a Gather operator. -class GatherLayer : public Layer +class GatherLayer : public LayerWithParameters { public: /// Makes a workload for the Gather type. @@ -24,7 +24,7 @@ public: /// @param [in] graph The graph into which this layer is being cloned. GatherLayer* Clone(Graph& graph) const override; - /// Check if the input tensor shape(s) + /// Check if the input tensor shape(s). /// will lead to a valid configuration of @ref GatherLayer. /// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validate. void ValidateTensorShapesFromInputs( @@ -34,8 +34,9 @@ public: protected: /// Constructor to create a GatherLayer. + /// @param [in] param GatherDescriptor to configure the stack operation. /// @param [in] name Optional name for the layer. - GatherLayer(const char* name); + GatherLayer(const GatherDescriptor& param, const char* name); /// Default destructor ~GatherLayer() = default; diff --git a/src/armnn/test/OptimizerTests.cpp b/src/armnn/test/OptimizerTests.cpp index 65ea91d402..3af50ecf3a 100644 --- a/src/armnn/test/OptimizerTests.cpp +++ b/src/armnn/test/OptimizerTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -458,7 +458,8 @@ void CreateGatherGraph(Graph& graph, const armnn::TensorInfo& paramsInfo, const Layer* input1 = graph.AddLayer(1, "indices"); input1->GetOutputSlot().SetTensorInfo(indicesInfo); - GatherLayer* layer = graph.AddLayer("gather"); + GatherDescriptor descriptor; + GatherLayer* layer = graph.AddLayer(descriptor, "gather"); layer->GetOutputSlot().SetTensorInfo(outputInfo); Layer* output = graph.AddLayer(0, "output"); diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp index e07e497ab8..6ab9f9e1e4 100644 --- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp +++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "TestNameAndDescriptorLayerVisitor.hpp" @@ -97,6 +97,12 @@ armnn::FillDescriptor GetDescriptor() return armnn::FillDescriptor(1); } +template<> +armnn::GatherDescriptor GetDescriptor() +{ + return armnn::GatherDescriptor(); +} + template<> armnn::InstanceNormalizationDescriptor GetDescriptor() { @@ -271,6 +277,7 @@ TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Comparison) TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Concat) TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(ElementwiseUnary) TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Fill) +TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Gather) TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(InstanceNormalization) TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(L2Normalization) TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(LogSoftmax) diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp index c8df505db0..df0e055157 100644 --- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp +++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -50,6 +50,7 @@ DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Concat) DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(DepthToSpace) DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(ElementwiseUnary) DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Fill) +DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Gather) DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(InstanceNormalization) DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(L2Normalization) DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(LogSoftmax) diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.cpp b/src/armnn/test/TestNameOnlyLayerVisitor.cpp index 0653b39e58..945afa8ff5 100644 --- a/src/armnn/test/TestNameOnlyLayerVisitor.cpp +++ b/src/armnn/test/TestNameOnlyLayerVisitor.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -42,7 +42,6 @@ TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Addition) TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Dequantize) TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Division) TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Floor) -TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Gather) TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Maximum) TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Merge) TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Minimum) diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.hpp b/src/armnn/test/TestNameOnlyLayerVisitor.hpp index 84dfdd6539..0e1ea8eac7 100644 --- a/src/armnn/test/TestNameOnlyLayerVisitor.hpp +++ b/src/armnn/test/TestNameOnlyLayerVisitor.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -29,7 +29,6 @@ DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Addition) DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Dequantize) DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Division) DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Floor) -DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Gather) DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Maximum) DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Merge) DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Minimum) diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index f59c757f4b..31fae2af86 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -2427,8 +2427,11 @@ void Deserializer::ParseGather(GraphPtr graph, unsigned int layerIndex) Deserializer::TensorRawPtrVector outputs = GetOutputs(graph, layerIndex); CHECK_VALID_SIZE(outputs.size(), 1); + armnn::GatherDescriptor descriptor; + descriptor.m_Axis = graph->layers()->Get(layerIndex)->layer_as_GatherLayer()->descriptor()->axis(); + auto layerName = GetLayerName(graph, layerIndex); - IConnectableLayer* layer = m_Network->AddGatherLayer(layerName.c_str()); + IConnectableLayer* layer = m_Network->AddGatherLayer(descriptor, layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); diff --git a/src/armnnDeserializer/test/DeserializeGather.cpp b/src/armnnDeserializer/test/DeserializeGather.cpp index 3fdcf51aed..0f75db4abb 100644 --- a/src/armnnDeserializer/test/DeserializeGather.cpp +++ b/src/armnnDeserializer/test/DeserializeGather.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -14,10 +14,11 @@ BOOST_AUTO_TEST_SUITE(Deserializer) struct GatherFixture : public ParserFlatbuffersSerializeFixture { - explicit GatherFixture(const std::string &inputShape, - const std::string &indicesShape, - const std::string &input1Content, - const std::string &outputShape, + explicit GatherFixture(const std::string& inputShape, + const std::string& indicesShape, + const std::string& input1Content, + const std::string& outputShape, + const std::string& axis, const std::string dataType, const std::string constDataType) { @@ -94,7 +95,10 @@ struct GatherFixture : public ParserFlatbuffersSerializeFixture dimensions: )" + outputShape + R"(, dataType: )" + dataType + R"( - }}]} + }}]}, + descriptor: { + axis: )" + axis + R"( + } }}, { layer_type: "OutputLayer", @@ -127,7 +131,7 @@ struct GatherFixture : public ParserFlatbuffersSerializeFixture struct SimpleGatherFixtureFloat32 : GatherFixture { SimpleGatherFixtureFloat32() : GatherFixture("[ 3, 2, 3 ]", "[ 2, 3 ]", "[1, 2, 1, 2, 1, 0]", - "[ 2, 3, 2, 3 ]", "Float32", "IntData") {} + "[ 2, 3, 2, 3 ]", "0", "Float32", "IntData") {} }; BOOST_FIXTURE_TEST_CASE(GatherFloat32, SimpleGatherFixtureFloat32) diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index 18415ce785..6a388db699 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -312,6 +312,11 @@ table FullyConnectedDescriptor { table GatherLayer { base:LayerBase; + descriptor:GatherDescriptor; +} + +table GatherDescriptor { + axis:int = 0; } /// @deprecated Use ComparisonLayer instead diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 17076c62ab..6555a34be7 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -493,12 +493,23 @@ void SerializerVisitor::VisitFloorLayer(const armnn::IConnectableLayer *layer, c CreateAnyLayer(flatBufferFloorLayer.o, serializer::Layer::Layer_FloorLayer); } -void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name) +void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, + const char* name) +{ + armnn::GatherDescriptor gatherDescriptor{}; + VisitGatherLayer(layer, gatherDescriptor, name); +} + +void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, + const armnn::GatherDescriptor& gatherDescriptor, + const char* name) { IgnoreUnused(name); + auto fbGatherDescriptor = CreateGatherDescriptor(m_flatBufferBuilder, + gatherDescriptor.m_Axis); auto fbGatherBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather); - auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer); + auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer, fbGatherDescriptor); CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherLayer); } diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index 65d87b7cf7..e4104dda8e 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -134,9 +134,14 @@ public: const armnn::Optional& biases, const char* name = nullptr) override; + ARMNN_DEPRECATED_MSG("Use VisitGatherLayer with descriptor instead") void VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name = nullptr) override; + void VisitGatherLayer(const armnn::IConnectableLayer* layer, + const armnn::GatherDescriptor& gatherDescriptor, + const char* name = nullptr) override; + ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead") void VisitGreaterLayer(const armnn::IConnectableLayer* layer, const char* name = nullptr) override; diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index fa43e09647..088282a18a 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -1333,17 +1333,22 @@ BOOST_AUTO_TEST_CASE(SerializeFullyConnected) BOOST_AUTO_TEST_CASE(SerializeGather) { - class GatherLayerVerifier : public LayerVerifierBase + using GatherDescriptor = armnn::GatherDescriptor; + class GatherLayerVerifier : public LayerVerifierBaseWithDescriptor { public: GatherLayerVerifier(const std::string& layerName, const std::vector& inputInfos, - const std::vector& outputInfos) - : LayerVerifierBase(layerName, inputInfos, outputInfos) {} + const std::vector& outputInfos, + const GatherDescriptor& descriptor) + : LayerVerifierBaseWithDescriptor(layerName, inputInfos, outputInfos, descriptor) {} - void VisitGatherLayer(const armnn::IConnectableLayer* layer, const char *name) override + void VisitGatherLayer(const armnn::IConnectableLayer* layer, + const GatherDescriptor& descriptor, + const char *name) override { VerifyNameAndConnections(layer, name); + BOOST_CHECK(descriptor.m_Axis == m_Descriptor.m_Axis); } void VisitConstantLayer(const armnn::IConnectableLayer*, @@ -1355,6 +1360,8 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::QAsymmU8); armnn::TensorInfo outputInfo({ 3 }, armnn::DataType::QAsymmU8); const armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32); + GatherDescriptor descriptor; + descriptor.m_Axis = 1; paramsInfo.SetQuantizationScale(1.0f); paramsInfo.SetQuantizationOffset(0); @@ -1367,7 +1374,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0); armnn::IConnectableLayer *const constantLayer = network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData)); - armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(layerName.c_str()); + armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(descriptor, layerName.c_str()); armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0); inputLayer->GetOutputSlot(0).Connect(gatherLayer->GetInputSlot(0)); @@ -1381,7 +1388,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); BOOST_CHECK(deserializedNetwork); - GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}); + GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}, descriptor); deserializedNetwork->Accept(verifier); } diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 7a7c5a4375..38202fcf94 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -1853,6 +1853,8 @@ ParsedTfOperationPtr TfParser::ParseGather(const tensorflow::NodeDef& nodeDef, std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); IOutputSlot& params = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); IOutputSlot& indices = inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index); + GatherDescriptor descriptor; + descriptor.m_Axis = ReadMandatoryNodeInt32Attribute(nodeDef, "axis"); // Infer shape of output tensor unsigned int paramsDim = params.GetTensorInfo().GetNumDimensions(); @@ -1874,7 +1876,7 @@ ParsedTfOperationPtr TfParser::ParseGather(const tensorflow::NodeDef& nodeDef, const TensorInfo inferredOutputInfo(inferredShape, params.GetTensorInfo().GetDataType()); - IConnectableLayer* const layer = m_Network->AddGatherLayer(nodeDef.name().c_str()); + IConnectableLayer* const layer = m_Network->AddGatherLayer(descriptor, nodeDef.name().c_str()); layer->GetOutputSlot(0).SetTensorInfo(inferredOutputInfo); params.Connect(layer->GetInputSlot(0)); diff --git a/src/armnnTfParser/test/Gather.cpp b/src/armnnTfParser/test/Gather.cpp index 8c4b891141..ab5fb7104d 100644 --- a/src/armnnTfParser/test/Gather.cpp +++ b/src/armnnTfParser/test/Gather.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -38,7 +38,8 @@ struct GatherFixture : public armnnUtils::ParserPrototxtFixture& input1Content, const std::vector& input0Dims, - const std::vector& input1Dims) + const std::vector& input1Dims, + int axis = 0) { m_Prototext = R"( node { @@ -56,6 +57,7 @@ node { shape { )"; dimsHelper(input0Dims, m_Prototext); + m_Prototext.append(R"( } } @@ -78,6 +80,7 @@ node { tensor_shape { )"); dimsHelper(input1Dims, m_Prototext); + m_Prototext.append(R"( } tensor_content: ")"); @@ -104,8 +107,18 @@ node { type: DT_FLOAT } } + attr { + key: "axis" + value { + i: )"); + m_Prototext += std::to_string(axis); + + m_Prototext.append(R"( + } + } } )"); + Setup({ { "input0", inputShape0 }, { "input1", inputShape1 } }, { "output" }); @@ -121,7 +134,8 @@ struct GatherFixture1DParams1DIndices : public GatherFixture { 4, 0, 0, 0 }, { 0, 2, 1, 3 }, { 4 }, - { 4 }) {} + { 4 }, + 0) {} }; struct GatherFixture1DParamsMultiDimIndices : public GatherFixture @@ -131,7 +145,8 @@ struct GatherFixture1DParamsMultiDimIndices : public GatherFixture { 2, 2, 1, 1 }, { 0, 1, 1, 3 }, { 4 }, - { 2, 2 }) {} + { 2, 2 }, + 0) {} }; struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture @@ -141,7 +156,8 @@ struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture { 2, 1, 4 }, { 1, 3, 0, 2 }, { 5, 2 }, - { 2, 2 }) {} + { 2, 2 }, + 0) {} }; BOOST_FIXTURE_TEST_CASE(ParseGather1DParams1DIndices, GatherFixture1DParams1DIndices) diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index e509a7b929..52e615a2d9 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -37,177 +37,177 @@ bool DefaultLayerSupport(const char* func, namespace armnn { -bool LayerSupportBase::IsAbsSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsAbsSupported(const TensorInfo&, // input + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsActivationSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const ActivationDescriptor& /*descriptor*/, +bool LayerSupportBase::IsActivationSupported(const TensorInfo&, // input + const TensorInfo&, //output + const ActivationDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsAdditionSupported(const TensorInfo& /*input0*/, - const TensorInfo& /*input1*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsAdditionSupported(const TensorInfo&, // input0 + const TensorInfo&, // input1 + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsArgMinMaxSupported(const armnn::TensorInfo &/*input*/, - const armnn::TensorInfo &/*output*/, - const armnn::ArgMinMaxDescriptor& /*descriptor*/, +bool LayerSupportBase::IsArgMinMaxSupported(const armnn::TensorInfo&, // input + const armnn::TensorInfo&, // output + const armnn::ArgMinMaxDescriptor&, // descriptor armnn::Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsBatchNormalizationSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const TensorInfo& /*mean*/, - const TensorInfo& /*var*/, - const TensorInfo& /*beta*/, - const TensorInfo& /*gamma*/, - const BatchNormalizationDescriptor& /*descriptor*/, +bool LayerSupportBase::IsBatchNormalizationSupported(const TensorInfo&, //input + const TensorInfo&, // output + const TensorInfo&, //mean + const TensorInfo&, //var + const TensorInfo&, //beta + const TensorInfo&, //gamma + const BatchNormalizationDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsBatchToSpaceNdSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const BatchToSpaceNdDescriptor& /*descriptor*/, +bool LayerSupportBase::IsBatchToSpaceNdSupported(const TensorInfo&, // input + const TensorInfo&, // output + const BatchToSpaceNdDescriptor&, //descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsComparisonSupported(const TensorInfo& /*input0*/, - const TensorInfo& /*input1*/, - const TensorInfo& /*output*/, - const ComparisonDescriptor& /*descriptor*/, +bool LayerSupportBase::IsComparisonSupported(const TensorInfo&, // input0 + const TensorInfo&, // input1 + const TensorInfo&, // output + const ComparisonDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsConcatSupported(const std::vector /*inputs*/, - const TensorInfo& /*output*/, - const OriginsDescriptor& /*descriptor*/, +bool LayerSupportBase::IsConcatSupported(const std::vector, // inputs + const TensorInfo&, // output + const OriginsDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsConstantSupported(const TensorInfo& /*output*/, +bool LayerSupportBase::IsConstantSupported(const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsConvertBf16ToFp32Supported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsConvertBf16ToFp32Supported(const TensorInfo&, // input + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsConvertFp16ToFp32Supported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsConvertFp16ToFp32Supported(const TensorInfo&, // input + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsConvertFp32ToBf16Supported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsConvertFp32ToBf16Supported(const TensorInfo&, // input + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsConvertFp32ToFp16Supported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsConvertFp32ToFp16Supported(const TensorInfo&, // input + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsConvolution2dSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const Convolution2dDescriptor& /*descriptor*/, - const TensorInfo& /*weights*/, - const Optional& /*biases*/, +bool LayerSupportBase::IsConvolution2dSupported(const TensorInfo&, // input + const TensorInfo&, // output + const Convolution2dDescriptor&, // descriptor + const TensorInfo&, // weights + const Optional&, // biases Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsDebugSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsDebugSupported(const TensorInfo&, // input + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsDepthToSpaceSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const DepthToSpaceDescriptor& /*descriptor*/, +bool LayerSupportBase::IsDepthToSpaceSupported(const TensorInfo&, // input + const TensorInfo&, // output + const DepthToSpaceDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsDepthwiseConvolutionSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const DepthwiseConvolution2dDescriptor& /*descriptor*/, - const TensorInfo& /*weights*/, - const Optional& /*biases*/, +bool LayerSupportBase::IsDepthwiseConvolutionSupported(const TensorInfo&, //input + const TensorInfo&, //output + const DepthwiseConvolution2dDescriptor&, // descriptor + const TensorInfo&, // weights + const Optional&, // biases Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsDequantizeSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsDequantizeSupported(const TensorInfo&, // input + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsDetectionPostProcessSupported(const TensorInfo& /*boxEncodings*/, - const TensorInfo& /*scores*/, - const TensorInfo& /*anchors*/, - const TensorInfo& /*detectionBoxes*/, - const TensorInfo& /*detectionClasses*/, - const TensorInfo& /*detectionScores*/, - const TensorInfo& /*numDetections*/, - const DetectionPostProcessDescriptor& /*descriptor*/, +bool LayerSupportBase::IsDetectionPostProcessSupported(const TensorInfo&, // boxEncodings + const TensorInfo&, // scores + const TensorInfo&, // anchors + const TensorInfo&, // detectionBoxes + const TensorInfo&, // detectionClasses + const TensorInfo&, // detectionScores + const TensorInfo&, // numDetections + const DetectionPostProcessDescriptor&, //descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const DepthwiseConvolution2dDescriptor& /*descriptor*/, - const TensorInfo& /*weights*/, - const Optional& /*biases*/, +bool LayerSupportBase::IsDilatedDepthwiseConvolutionSupported(const TensorInfo&, // input + const TensorInfo&, // output + const DepthwiseConvolution2dDescriptor&, // descriptor + const TensorInfo&,// weights + const Optional&, // biases Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsDivisionSupported(const TensorInfo& /*input0*/, - const TensorInfo& /*input1*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsDivisionSupported(const TensorInfo&, // input0 + const TensorInfo&, // input1 + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); @@ -233,139 +233,148 @@ bool LayerSupportBase::IsElementwiseUnarySupported(const TensorInfo& input, return false; } -bool LayerSupportBase::IsEqualSupported(const armnn::TensorInfo& /*input0*/, - const armnn::TensorInfo& /*input1*/, - const armnn::TensorInfo& /*output*/, +bool LayerSupportBase::IsEqualSupported(const armnn::TensorInfo&, // input0 + const armnn::TensorInfo&, // input1 + const armnn::TensorInfo&, // output armnn::Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsFakeQuantizationSupported(const TensorInfo& /*input*/, - const FakeQuantizationDescriptor& /*descriptor*/, +bool LayerSupportBase::IsFakeQuantizationSupported(const TensorInfo&, // input + const FakeQuantizationDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsFillSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const FillDescriptor& /*descriptor*/, +bool LayerSupportBase::IsFillSupported(const TensorInfo&, // input + const TensorInfo&, // output + const FillDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsFloorSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsFloorSupported(const TensorInfo&, // input + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsFullyConnectedSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const TensorInfo& /*weights*/, - const TensorInfo& /*biases*/, - const FullyConnectedDescriptor& /*descriptor*/, +bool LayerSupportBase::IsFullyConnectedSupported(const TensorInfo&, // input + const TensorInfo&, // output + const TensorInfo&, // weights + const TensorInfo&, // biases + const FullyConnectedDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsGatherSupported(const armnn::TensorInfo& /*input0*/, - const armnn::TensorInfo& /*input1*/, - const armnn::TensorInfo& /*output*/, +bool LayerSupportBase::IsGatherSupported(const armnn::TensorInfo&, // input0 + const armnn::TensorInfo&, // input1 + const armnn::TensorInfo&, // output armnn::Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsGreaterSupported(const TensorInfo& /*input0*/, - const TensorInfo& /*input1*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsGatherSupported(const armnn::TensorInfo&, // input0 + const armnn::TensorInfo&, // input1 + const armnn::TensorInfo&, // output + const GatherDescriptor&, // descriptor + armnn::Optional reasonIfUnsupported) const +{ + return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); +} + +bool LayerSupportBase::IsGreaterSupported(const TensorInfo&, // input0 + const TensorInfo&, // input1 + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsInputSupported(const TensorInfo& /*input*/, +bool LayerSupportBase::IsInputSupported(const TensorInfo&, // input Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsInstanceNormalizationSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const InstanceNormalizationDescriptor& /*descriptor*/, +bool LayerSupportBase::IsInstanceNormalizationSupported(const TensorInfo&, // input + const TensorInfo&, // output + const InstanceNormalizationDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsL2NormalizationSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const L2NormalizationDescriptor& /*descriptor*/, +bool LayerSupportBase::IsL2NormalizationSupported(const TensorInfo&, // input + const TensorInfo&, // output + const L2NormalizationDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsLogSoftmaxSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const LogSoftmaxDescriptor& /*descriptor*/, +bool LayerSupportBase::IsLogSoftmaxSupported(const TensorInfo&, // input + const TensorInfo&, // output + const LogSoftmaxDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsLstmSupported(const TensorInfo& /*input*/, - const TensorInfo& /*outputStateIn*/, - const TensorInfo& /*cellStateIn*/, - const TensorInfo& /*scratchBuffer*/, - const TensorInfo& /*outputStateOut*/, - const TensorInfo& /*cellStateOut*/, - const TensorInfo& /*output*/, - const LstmDescriptor& /*descriptor*/, - const LstmInputParamsInfo& /*paramsInfo*/, +bool LayerSupportBase::IsLstmSupported(const TensorInfo&, // input + const TensorInfo&, // outputStateIn + const TensorInfo&, // cellStateIn + const TensorInfo&, // scratchBuffer + const TensorInfo&, // outputStateOut + const TensorInfo&, // cellStateOut + const TensorInfo&, // output + const LstmDescriptor&, // descriptor + const LstmInputParamsInfo&, // paramsInfo Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsMaximumSupported(const TensorInfo& /*input0*/, - const TensorInfo& /*input1*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsMaximumSupported(const TensorInfo&, // input0 + const TensorInfo&, // input1 + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsMeanSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const MeanDescriptor& /*descriptor*/, +bool LayerSupportBase::IsMeanSupported(const TensorInfo&, // input + const TensorInfo&, // output + const MeanDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsMemCopySupported(const armnn::TensorInfo& /*input*/, - const armnn::TensorInfo& /*output*/, - armnn::Optional /*reasonIfUnsupported*/) const +bool LayerSupportBase::IsMemCopySupported(const armnn::TensorInfo&, // input + const armnn::TensorInfo&, // output + armnn::Optional ) const // reasonIfUnsupported { return true; } -bool LayerSupportBase::IsMemImportSupported(const armnn::TensorInfo& /*input*/, - const armnn::TensorInfo& /*output*/, - armnn::Optional /*reasonIfUnsupported*/) const +bool LayerSupportBase::IsMemImportSupported(const armnn::TensorInfo&, // input + const armnn::TensorInfo&, // output + armnn::Optional ) const // reasonIfUnsupported { return true; } -bool LayerSupportBase::IsMergeSupported(const TensorInfo& /*input0*/, - const TensorInfo& /*input1*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsMergeSupported(const TensorInfo&, // input0 + const TensorInfo&, // input1 + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); @@ -379,194 +388,194 @@ bool LayerSupportBase::IsMergerSupported(const std::vector in return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported); } -bool LayerSupportBase::IsMinimumSupported(const TensorInfo& /*input0*/, - const TensorInfo& /*input1*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsMinimumSupported(const TensorInfo&, // input0 + const TensorInfo&, // input1 + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsMultiplicationSupported(const TensorInfo& /*input0*/, - const TensorInfo& /*input1*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsMultiplicationSupported(const TensorInfo&, // input0 + const TensorInfo&, // input1 + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsNormalizationSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const NormalizationDescriptor& /*descriptor*/, +bool LayerSupportBase::IsNormalizationSupported(const TensorInfo&, // input + const TensorInfo&, // output + const NormalizationDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsOutputSupported(const TensorInfo& /*output*/, +bool LayerSupportBase::IsOutputSupported(const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsPadSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const PadDescriptor& /*descriptor*/, +bool LayerSupportBase::IsPadSupported(const TensorInfo&, // input + const TensorInfo&, // output + const PadDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsPermuteSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const PermuteDescriptor& /*descriptor*/, +bool LayerSupportBase::IsPermuteSupported(const TensorInfo&, // input + const TensorInfo&, // output + const PermuteDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsPooling2dSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const Pooling2dDescriptor& /*descriptor*/, +bool LayerSupportBase::IsPooling2dSupported(const TensorInfo&, // input + const TensorInfo&, // output + const Pooling2dDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsPreCompiledSupported(const TensorInfo& /*input*/, - const PreCompiledDescriptor& /*descriptor*/, +bool LayerSupportBase::IsPreCompiledSupported(const TensorInfo&, // input + const PreCompiledDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsPreluSupported(const TensorInfo& /*input*/, - const TensorInfo& /*alpha*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsPreluSupported(const TensorInfo&, // input + const TensorInfo&, // alpha + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsQuantizeSupported(const armnn::TensorInfo& /*input*/, - const armnn::TensorInfo& /*output*/, +bool LayerSupportBase::IsQuantizeSupported(const armnn::TensorInfo&, // input + const armnn::TensorInfo&, // output armnn::Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsQLstmSupported(const TensorInfo& /*input*/, - const TensorInfo& /*previousOutputIn*/, - const TensorInfo& /*previousCellStateIn*/, - const TensorInfo& /*outputStateOut*/, - const TensorInfo& /*cellStateOut*/, - const TensorInfo& /*output*/, - const QLstmDescriptor& /*descriptor*/, - const LstmInputParamsInfo& /*paramsInfo*/, +bool LayerSupportBase::IsQLstmSupported(const TensorInfo&, // input + const TensorInfo&, // previousOutputIn + const TensorInfo&, // previousCellStateIn + const TensorInfo&, // outputStateOut + const TensorInfo&, // cellStateOut + const TensorInfo&, // output + const QLstmDescriptor&, // descriptor + const LstmInputParamsInfo&, // paramsInfo Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsQuantizedLstmSupported(const TensorInfo& /*input*/, - const TensorInfo& /*previousCellStateIn*/, - const TensorInfo& /*previousOutputIn*/, - const TensorInfo& /*cellStateOut*/, - const TensorInfo& /*output*/, - const QuantizedLstmInputParamsInfo& /*paramsInfo*/, +bool LayerSupportBase::IsQuantizedLstmSupported(const TensorInfo&, // input + const TensorInfo&, // previousCellStateIn + const TensorInfo&, // previousOutputIn + const TensorInfo&, // cellStateOut + const TensorInfo&, // output + const QuantizedLstmInputParamsInfo&, // paramsInfo Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsReshapeSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const ReshapeDescriptor& /*descriptor*/, +bool LayerSupportBase::IsReshapeSupported(const TensorInfo&, // input + const TensorInfo&, // output + const ReshapeDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsResizeBilinearSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsResizeBilinearSupported(const TensorInfo&, // input + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsResizeSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const ResizeDescriptor& /*descriptor*/, +bool LayerSupportBase::IsResizeSupported(const TensorInfo&, // input + const TensorInfo&, // output + const ResizeDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsRsqrtSupported(const TensorInfo &/*input*/, - const TensorInfo &/*output*/, +bool LayerSupportBase::IsRsqrtSupported(const TensorInfo&, // input + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsSliceSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const SliceDescriptor& /*descriptor*/, +bool LayerSupportBase::IsSliceSupported(const TensorInfo&, // input + const TensorInfo&, // output + const SliceDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsSoftmaxSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const SoftmaxDescriptor& /*descriptor*/, +bool LayerSupportBase::IsSoftmaxSupported(const TensorInfo&, // input + const TensorInfo&, // output + const SoftmaxDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } /**/ -bool LayerSupportBase::IsSpaceToBatchNdSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const SpaceToBatchNdDescriptor& /*descriptor*/, +bool LayerSupportBase::IsSpaceToBatchNdSupported(const TensorInfo&, // input + const TensorInfo&, // output + const SpaceToBatchNdDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsSpaceToDepthSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const SpaceToDepthDescriptor& /*descriptor*/, +bool LayerSupportBase::IsSpaceToDepthSupported(const TensorInfo&, // input + const TensorInfo&, // output + const SpaceToDepthDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsSplitterSupported(const TensorInfo& /*input*/, - const ViewsDescriptor& /*descriptor*/, +bool LayerSupportBase::IsSplitterSupported(const TensorInfo&, // input + const ViewsDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsSplitterSupported(const TensorInfo& /*input*/, - const std::vector>& /*outputs*/, - const ViewsDescriptor& /*descriptor*/, +bool LayerSupportBase::IsSplitterSupported(const TensorInfo&, // input + const std::vector>&, // outputs + const ViewsDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsStackSupported(const std::vector& /*inputs*/, - const TensorInfo& /*output*/, - const StackDescriptor& /*descriptor*/, +bool LayerSupportBase::IsStackSupported(const std::vector&, // inputs + const TensorInfo&, // output + const StackDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsStandInSupported(const std::vector& /*inputs*/, - const std::vector& /*outputs*/, - const StandInDescriptor& /*descriptor*/, +bool LayerSupportBase::IsStandInSupported(const std::vector&, // inputs + const std::vector&, // outputs + const StandInDescriptor&, // descriptor Optional reasonIfUnsupported) const { if (reasonIfUnsupported) @@ -580,44 +589,44 @@ bool LayerSupportBase::IsStandInSupported(const std::vector& return false; } -bool LayerSupportBase::IsStridedSliceSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const StridedSliceDescriptor& /*descriptor*/, +bool LayerSupportBase::IsStridedSliceSupported(const TensorInfo&, // input + const TensorInfo&, // output + const StridedSliceDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsSubtractionSupported(const TensorInfo& /*input0*/, - const TensorInfo& /*input1*/, - const TensorInfo& /*output*/, +bool LayerSupportBase::IsSubtractionSupported(const TensorInfo&, // input0 + const TensorInfo&, // input1 + const TensorInfo&, // output Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsSwitchSupported(const TensorInfo& /*input0*/, - const TensorInfo& /*input1*/, - const TensorInfo& /*output0*/, - const TensorInfo& /*output1*/, +bool LayerSupportBase::IsSwitchSupported(const TensorInfo&, // input0 + const TensorInfo&, // input1 + const TensorInfo&, // output0 + const TensorInfo&, // output1 Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } -bool LayerSupportBase::IsTransposeConvolution2dSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const TransposeConvolution2dDescriptor& /*descriptor*/, - const TensorInfo& /*weights*/, - const Optional& /*biases*/, +bool LayerSupportBase::IsTransposeConvolution2dSupported(const TensorInfo&, // input + const TensorInfo&, // output + const TransposeConvolution2dDescriptor&, // descriptor + const TensorInfo&, // weights + const Optional&, // biases Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); -} +} -bool LayerSupportBase::IsTransposeSupported(const TensorInfo& /*input*/, - const TensorInfo& /*output*/, - const TransposeDescriptor& /*descriptor*/, +bool LayerSupportBase::IsTransposeSupported(const TensorInfo&, // input + const TensorInfo&, // output + const TransposeDescriptor&, // descriptor Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index aff4529417..8d5535ab4e 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -159,11 +159,18 @@ public: const FullyConnectedDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()) const override; + ARMNN_DEPRECATED_MSG("Use IsGatherSupported with descriptor instead") bool IsGatherSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsGatherSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const GatherDescriptor& descriptor, + Optional reasonIfUnsupported = EmptyOptional()) const override; + ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead") bool IsGreaterSupported(const TensorInfo& input0, const TensorInfo& input1, diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index f2f7089040..6b2c00e298 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -471,7 +471,7 @@ struct RsqrtQueueDescriptor : QueueDescriptor void Validate(const WorkloadInfo& workloadInfo) const; }; -struct GatherQueueDescriptor : QueueDescriptor +struct GatherQueueDescriptor : QueueDescriptorWithParameters { void Validate(const WorkloadInfo& workloadInfo) const; }; diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index d2565cf21d..788cb7e712 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -414,9 +414,12 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo(); const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo(); + auto cLayer = PolymorphicDowncast(&layer); + const GatherDescriptor& descriptor = cLayer->GetParameters(); result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType), input1, OverrideDataType(output, dataType), + descriptor, reason); break; } diff --git a/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp b/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp index 1c97bef467..82f94512c3 100644 --- a/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp +++ b/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -19,9 +19,10 @@ armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo, { armnn::INetworkPtr net(armnn::INetwork::Create()); + armnn::GatherDescriptor descriptor; armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0); armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData)); - armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer("gather"); + armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer(descriptor, "gather"); armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output"); Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0); Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1); diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp index dcd073d279..e30cbb3d31 100644 --- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp +++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -551,7 +551,7 @@ DECLARE_LAYER_POLICY_1_PARAM(Floor) DECLARE_LAYER_POLICY_2_PARAM(FullyConnected) -DECLARE_LAYER_POLICY_1_PARAM(Gather) +DECLARE_LAYER_POLICY_2_PARAM(Gather) DECLARE_LAYER_POLICY_CUSTOM_PARAM(Input, armnn::LayerBindingId) diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index 44da423f92..0bff96345a 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -466,13 +466,15 @@ bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input, bool ClLayerSupport::IsGatherSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, + const GatherDescriptor& descriptor, Optional reasonIfUnsupported) const { FORWARD_WORKLOAD_VALIDATE_FUNC(ClGatherWorkloadValidate, reasonIfUnsupported, input0, input1, - output); + output, + descriptor); } bool ClLayerSupport::IsGreaterSupported(const TensorInfo& input0, diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp index d3c3295d2f..49100bf6a2 100644 --- a/src/backends/cl/ClLayerSupport.hpp +++ b/src/backends/cl/ClLayerSupport.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -127,6 +127,7 @@ public: bool IsGatherSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, + const GatherDescriptor& descriptor, Optional reasonIfUnsupported) const override; ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead") diff --git a/src/backends/cl/workloads/ClGatherWorkload.cpp b/src/backends/cl/workloads/ClGatherWorkload.cpp index 068487039b..c76b9c7a17 100644 --- a/src/backends/cl/workloads/ClGatherWorkload.cpp +++ b/src/backends/cl/workloads/ClGatherWorkload.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2020 Arm Ltd. All rights reserved. +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -14,13 +14,14 @@ namespace armnn { arm_compute::Status ClGatherWorkloadValidate(const TensorInfo& input, const TensorInfo& indices, - const TensorInfo& output) + const TensorInfo& output, + const GatherDescriptor& descriptor) { const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input); const arm_compute::TensorInfo aclIndices = BuildArmComputeTensorInfo(indices); const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output); - int aclAxis = ComputeAclAxis(0, input); + int aclAxis = ComputeAclAxis(descriptor.m_Axis, input); return arm_compute::CLGather::validate(&aclInput, &aclIndices, &aclOutput, aclAxis); } @@ -35,7 +36,7 @@ ClGatherWorkload::ClGatherWorkload(const GatherQueueDescriptor& descriptor, arm_compute::ICLTensor& indices = static_cast(m_Data.m_Inputs[1])->GetTensor(); arm_compute::ICLTensor& output = static_cast(m_Data.m_Outputs[0])->GetTensor(); - int aclAxis = ComputeAclAxis(0, info.m_InputTensorInfos[0]); + int aclAxis = ComputeAclAxis(descriptor.m_Parameters.m_Axis, info.m_InputTensorInfos[0]); m_Layer.configure(&input, &indices, &output, aclAxis); }; diff --git a/src/backends/cl/workloads/ClGatherWorkload.hpp b/src/backends/cl/workloads/ClGatherWorkload.hpp index 5dbeaade59..df71a99fa0 100644 --- a/src/backends/cl/workloads/ClGatherWorkload.hpp +++ b/src/backends/cl/workloads/ClGatherWorkload.hpp @@ -13,7 +13,8 @@ namespace armnn { arm_compute::Status ClGatherWorkloadValidate(const TensorInfo& input, const TensorInfo& indices, - const TensorInfo& output); + const TensorInfo& output, + const GatherDescriptor& descriptor); class ClGatherWorkload : public BaseWorkload { diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp index b611bf45f9..f6b3b7627a 100644 --- a/src/backends/neon/NeonLayerSupport.cpp +++ b/src/backends/neon/NeonLayerSupport.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -447,13 +447,15 @@ bool NeonLayerSupport::IsFullyConnectedSupported(const TensorInfo& input, bool NeonLayerSupport::IsGatherSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, + const GatherDescriptor& descriptor, Optional reasonIfUnsupported) const { FORWARD_WORKLOAD_VALIDATE_FUNC(NeonGatherWorkloadValidate, reasonIfUnsupported, input0, input1, - output); + output, + descriptor); } bool NeonLayerSupport::IsGreaterSupported(const armnn::TensorInfo& input0, diff --git a/src/backends/neon/NeonLayerSupport.hpp b/src/backends/neon/NeonLayerSupport.hpp index 7217fc8971..aff62d18d7 100644 --- a/src/backends/neon/NeonLayerSupport.hpp +++ b/src/backends/neon/NeonLayerSupport.hpp @@ -131,6 +131,7 @@ public: bool IsGatherSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, + const GatherDescriptor& descriptor, Optional reasonIfUnsupported) const override; ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead") diff --git a/src/backends/neon/workloads/NeonGatherWorkload.cpp b/src/backends/neon/workloads/NeonGatherWorkload.cpp index 2e7c741781..2c94cb5184 100644 --- a/src/backends/neon/workloads/NeonGatherWorkload.cpp +++ b/src/backends/neon/workloads/NeonGatherWorkload.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2020 Arm Ltd. All rights reserved. +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -12,13 +12,14 @@ namespace armnn { arm_compute::Status NeonGatherWorkloadValidate(const TensorInfo& input, const TensorInfo& indices, - const TensorInfo& output) + const TensorInfo& output, + const GatherDescriptor& descriptor) { const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input); const arm_compute::TensorInfo aclIndices = BuildArmComputeTensorInfo(indices); const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output); - int aclAxis = ComputeAclAxis(0, input); + int aclAxis = ComputeAclAxis(descriptor.m_Axis, input); return arm_compute::NEGather::validate(&aclInput, &aclIndices, &aclOutput, aclAxis); } @@ -33,7 +34,7 @@ NeonGatherWorkload::NeonGatherWorkload(const GatherQueueDescriptor& descriptor, arm_compute::ITensor& indices = PolymorphicDowncast(m_Data.m_Inputs[1])->GetTensor(); arm_compute::ITensor& output = PolymorphicDowncast(m_Data.m_Outputs[0])->GetTensor(); - int aclAxis = ComputeAclAxis(0, info.m_InputTensorInfos[0]); + int aclAxis = ComputeAclAxis(descriptor.m_Parameters.m_Axis, info.m_InputTensorInfos[0]); m_Layer.configure(&input, &indices, &output, aclAxis); } diff --git a/src/backends/neon/workloads/NeonGatherWorkload.hpp b/src/backends/neon/workloads/NeonGatherWorkload.hpp index b1b47a5069..e5b7b57629 100644 --- a/src/backends/neon/workloads/NeonGatherWorkload.hpp +++ b/src/backends/neon/workloads/NeonGatherWorkload.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2020 Arm Ltd. All rights reserved. +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -13,7 +13,8 @@ namespace armnn { arm_compute::Status NeonGatherWorkloadValidate(const TensorInfo& input, const TensorInfo& indices, - const TensorInfo& output); + const TensorInfo& output, + const GatherDescriptor& descriptor); class NeonGatherWorkload : public BaseWorkload { diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 18b36a5fa8..696c6d9dac 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -987,6 +987,7 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input, bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0, const armnn::TensorInfo& input1, const armnn::TensorInfo& output, + const GatherDescriptor& descriptor, armnn::Optional reasonIfUnsupported) const { bool supported = true; @@ -1001,6 +1002,11 @@ bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0, DataType::Signed32 }; + if (descriptor.m_Axis != 0) + { + reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n"); + supported &= false; + } supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported, "Reference Gather: input type not supported"); diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 96bff56a42..7d2bbf240e 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -160,6 +160,7 @@ public: bool IsGatherSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, + const GatherDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()) const override; ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead") @@ -346,7 +347,6 @@ public: const TensorInfo& output, const TransposeDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()) const override; - }; } // namespace armnn diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp index c23edcd3bd..3e2190c81b 100644 --- a/src/backends/reference/workloads/Gather.cpp +++ b/src/backends/reference/workloads/Gather.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -20,9 +20,12 @@ void Gather(const TensorInfo& paramsInfo, const TensorInfo& outputInfo, Decoder& params, const int32_t* indices, - Encoder& output) + Encoder& output, + const int32_t axis) { IgnoreUnused(outputInfo); + IgnoreUnused(axis); + const TensorShape& paramsShape = paramsInfo.GetShape(); unsigned int paramsProduct = 1; diff --git a/src/backends/reference/workloads/Gather.hpp b/src/backends/reference/workloads/Gather.hpp index 16c983eec4..1550f4b97c 100644 --- a/src/backends/reference/workloads/Gather.hpp +++ b/src/backends/reference/workloads/Gather.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -19,6 +19,7 @@ void Gather(const TensorInfo& paramsInfo, const TensorInfo& outputInfo, Decoder& params, const int32_t* indices, - Encoder& output); + Encoder& output, + const int32_t = 0); } //namespace armnn diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp index 8edf14c8f8..eaeed61b0a 100644 --- a/src/backends/reference/workloads/RefGatherWorkload.cpp +++ b/src/backends/reference/workloads/RefGatherWorkload.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -29,7 +29,7 @@ void RefGatherWorkload::Execute() const std::unique_ptr> encoderPtr = MakeEncoder(outputInfo, m_Data.m_Outputs[0]->Map()); Encoder& encoder = *encoderPtr; - Gather(inputInfo0, inputInfo1, outputInfo, decoder, indicesData, encoder); + Gather(inputInfo0, inputInfo1, outputInfo, decoder, indicesData, encoder, m_Data.m_Parameters.m_Axis); } } //namespace armnn -- cgit v1.2.1