diff options
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r-- | src/armnnSerializer/ArmnnSchema.fbs | 7 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 17 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.hpp | 7 | ||||
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 21 |
4 files changed, 40 insertions, 12 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index 18415ce785..6a388db699 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -312,6 +312,11 @@ table FullyConnectedDescriptor { table GatherLayer { base:LayerBase; + descriptor:GatherDescriptor; +} + +table GatherDescriptor { + axis:int = 0; } /// @deprecated Use ComparisonLayer instead diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 17076c62ab..6555a34be7 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -493,12 +493,23 @@ void SerializerVisitor::VisitFloorLayer(const armnn::IConnectableLayer *layer, c CreateAnyLayer(flatBufferFloorLayer.o, serializer::Layer::Layer_FloorLayer); } -void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name) +void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, + const char* name) +{ + armnn::GatherDescriptor gatherDescriptor{}; + VisitGatherLayer(layer, gatherDescriptor, name); +} + +void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, + const armnn::GatherDescriptor& gatherDescriptor, + const char* name) { IgnoreUnused(name); + auto fbGatherDescriptor = CreateGatherDescriptor(m_flatBufferBuilder, + gatherDescriptor.m_Axis); auto fbGatherBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather); - auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer); + auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer, fbGatherDescriptor); CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherLayer); } diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index 65d87b7cf7..e4104dda8e 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -134,9 +134,14 @@ public: const armnn::Optional<armnn::ConstTensor>& biases, const char* name = nullptr) override; + ARMNN_DEPRECATED_MSG("Use VisitGatherLayer with descriptor instead") void VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name = nullptr) override; + void VisitGatherLayer(const armnn::IConnectableLayer* layer, + const armnn::GatherDescriptor& gatherDescriptor, + const char* name = nullptr) override; + ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead") void VisitGreaterLayer(const armnn::IConnectableLayer* layer, const char* name = nullptr) override; diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index fa43e09647..088282a18a 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -1333,17 +1333,22 @@ BOOST_AUTO_TEST_CASE(SerializeFullyConnected) BOOST_AUTO_TEST_CASE(SerializeGather) { - class GatherLayerVerifier : public LayerVerifierBase + using GatherDescriptor = armnn::GatherDescriptor; + class GatherLayerVerifier : public LayerVerifierBaseWithDescriptor<GatherDescriptor> { public: GatherLayerVerifier(const std::string& layerName, const std::vector<armnn::TensorInfo>& inputInfos, - const std::vector<armnn::TensorInfo>& outputInfos) - : LayerVerifierBase(layerName, inputInfos, outputInfos) {} + const std::vector<armnn::TensorInfo>& outputInfos, + const GatherDescriptor& descriptor) + : LayerVerifierBaseWithDescriptor<GatherDescriptor>(layerName, inputInfos, outputInfos, descriptor) {} - void VisitGatherLayer(const armnn::IConnectableLayer* layer, const char *name) override + void VisitGatherLayer(const armnn::IConnectableLayer* layer, + const GatherDescriptor& descriptor, + const char *name) override { VerifyNameAndConnections(layer, name); + BOOST_CHECK(descriptor.m_Axis == m_Descriptor.m_Axis); } void VisitConstantLayer(const armnn::IConnectableLayer*, @@ -1355,6 +1360,8 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::QAsymmU8); armnn::TensorInfo outputInfo({ 3 }, armnn::DataType::QAsymmU8); const armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32); + GatherDescriptor descriptor; + descriptor.m_Axis = 1; paramsInfo.SetQuantizationScale(1.0f); paramsInfo.SetQuantizationOffset(0); @@ -1367,7 +1374,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0); armnn::IConnectableLayer *const constantLayer = network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData)); - armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(layerName.c_str()); + armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(descriptor, layerName.c_str()); armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0); inputLayer->GetOutputSlot(0).Connect(gatherLayer->GetInputSlot(0)); @@ -1381,7 +1388,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); BOOST_CHECK(deserializedNetwork); - GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}); + GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}, descriptor); deserializedNetwork->Accept(verifier); } |