diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2020-06-29 16:27:03 +0100 |
---|---|---|
committer | TeresaARM <teresa.charlinreyes@arm.com> | 2020-07-01 08:26:47 +0000 |
commit | 526647333571169076f5e72c9fb18c71025bf7c0 (patch) | |
tree | 6dc559a7b0fae3705172b09a88fa552926652040 /src/armnnSerializer | |
parent | cbd2c230b7ce5f26e2ccccf36b7ad450f6e1ad09 (diff) | |
download | armnn-526647333571169076f5e72c9fb18c71025bf7c0.tar.gz |
IVGCVSW-4903 Connect axis parameter in Gather from android to ACL.
!android-nn-driver:3302
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: Ifbc49acb5272f8a36719bb68676e44817190537d
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r-- | src/armnnSerializer/ArmnnSchema.fbs | 7 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 17 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.hpp | 7 | ||||
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 21 |
4 files changed, 40 insertions, 12 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index 18415ce785..6a388db699 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -312,6 +312,11 @@ table FullyConnectedDescriptor { table GatherLayer { base:LayerBase; + descriptor:GatherDescriptor; +} + +table GatherDescriptor { + axis:int = 0; } /// @deprecated Use ComparisonLayer instead diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 17076c62ab..6555a34be7 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -493,12 +493,23 @@ void SerializerVisitor::VisitFloorLayer(const armnn::IConnectableLayer *layer, c CreateAnyLayer(flatBufferFloorLayer.o, serializer::Layer::Layer_FloorLayer); } -void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name) +void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, + const char* name) +{ + armnn::GatherDescriptor gatherDescriptor{}; + VisitGatherLayer(layer, gatherDescriptor, name); +} + +void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, + const armnn::GatherDescriptor& gatherDescriptor, + const char* name) { IgnoreUnused(name); + auto fbGatherDescriptor = CreateGatherDescriptor(m_flatBufferBuilder, + gatherDescriptor.m_Axis); auto fbGatherBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather); - auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer); + auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer, fbGatherDescriptor); CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherLayer); } diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index 65d87b7cf7..e4104dda8e 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -134,9 +134,14 @@ public: const armnn::Optional<armnn::ConstTensor>& biases, const char* name = nullptr) override; + ARMNN_DEPRECATED_MSG("Use VisitGatherLayer with descriptor instead") void VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name = nullptr) override; + void VisitGatherLayer(const armnn::IConnectableLayer* layer, + const armnn::GatherDescriptor& gatherDescriptor, + const char* name = nullptr) override; + ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead") void VisitGreaterLayer(const armnn::IConnectableLayer* layer, const char* name = nullptr) override; diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index fa43e09647..088282a18a 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -1333,17 +1333,22 @@ BOOST_AUTO_TEST_CASE(SerializeFullyConnected) BOOST_AUTO_TEST_CASE(SerializeGather) { - class GatherLayerVerifier : public LayerVerifierBase + using GatherDescriptor = armnn::GatherDescriptor; + class GatherLayerVerifier : public LayerVerifierBaseWithDescriptor<GatherDescriptor> { public: GatherLayerVerifier(const std::string& layerName, const std::vector<armnn::TensorInfo>& inputInfos, - const std::vector<armnn::TensorInfo>& outputInfos) - : LayerVerifierBase(layerName, inputInfos, outputInfos) {} + const std::vector<armnn::TensorInfo>& outputInfos, + const GatherDescriptor& descriptor) + : LayerVerifierBaseWithDescriptor<GatherDescriptor>(layerName, inputInfos, outputInfos, descriptor) {} - void VisitGatherLayer(const armnn::IConnectableLayer* layer, const char *name) override + void VisitGatherLayer(const armnn::IConnectableLayer* layer, + const GatherDescriptor& descriptor, + const char *name) override { VerifyNameAndConnections(layer, name); + BOOST_CHECK(descriptor.m_Axis == m_Descriptor.m_Axis); } void VisitConstantLayer(const armnn::IConnectableLayer*, @@ -1355,6 +1360,8 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::QAsymmU8); armnn::TensorInfo outputInfo({ 3 }, armnn::DataType::QAsymmU8); const armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32); + GatherDescriptor descriptor; + descriptor.m_Axis = 1; paramsInfo.SetQuantizationScale(1.0f); paramsInfo.SetQuantizationOffset(0); @@ -1367,7 +1374,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0); armnn::IConnectableLayer *const constantLayer = network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData)); - armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(layerName.c_str()); + armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(descriptor, layerName.c_str()); armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0); inputLayer->GetOutputSlot(0).Connect(gatherLayer->GetInputSlot(0)); @@ -1381,7 +1388,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); BOOST_CHECK(deserializedNetwork); - GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}); + GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}, descriptor); deserializedNetwork->Accept(verifier); } |