aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2020-06-29 16:27:03 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2020-07-01 08:26:47 +0000
commit526647333571169076f5e72c9fb18c71025bf7c0 (patch)
tree6dc559a7b0fae3705172b09a88fa552926652040 /src/armnnSerializer
parentcbd2c230b7ce5f26e2ccccf36b7ad450f6e1ad09 (diff)
downloadarmnn-526647333571169076f5e72c9fb18c71025bf7c0.tar.gz
IVGCVSW-4903 Connect axis parameter in Gather from android to ACL.
!android-nn-driver:3302 Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: Ifbc49acb5272f8a36719bb68676e44817190537d
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r--src/armnnSerializer/ArmnnSchema.fbs7
-rw-r--r--src/armnnSerializer/Serializer.cpp17
-rw-r--r--src/armnnSerializer/Serializer.hpp7
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp21
4 files changed, 40 insertions, 12 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 18415ce785..6a388db699 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -312,6 +312,11 @@ table FullyConnectedDescriptor {
table GatherLayer {
base:LayerBase;
+ descriptor:GatherDescriptor;
+}
+
+table GatherDescriptor {
+ axis:int = 0;
}
/// @deprecated Use ComparisonLayer instead
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 17076c62ab..6555a34be7 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -493,12 +493,23 @@ void SerializerVisitor::VisitFloorLayer(const armnn::IConnectableLayer *layer, c
CreateAnyLayer(flatBufferFloorLayer.o, serializer::Layer::Layer_FloorLayer);
}
-void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name)
+void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer,
+ const char* name)
+{
+ armnn::GatherDescriptor gatherDescriptor{};
+ VisitGatherLayer(layer, gatherDescriptor, name);
+}
+
+void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer,
+ const armnn::GatherDescriptor& gatherDescriptor,
+ const char* name)
{
IgnoreUnused(name);
+ auto fbGatherDescriptor = CreateGatherDescriptor(m_flatBufferBuilder,
+ gatherDescriptor.m_Axis);
auto fbGatherBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather);
- auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer);
+ auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer, fbGatherDescriptor);
CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherLayer);
}
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 65d87b7cf7..e4104dda8e 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -134,9 +134,14 @@ public:
const armnn::Optional<armnn::ConstTensor>& biases,
const char* name = nullptr) override;
+ ARMNN_DEPRECATED_MSG("Use VisitGatherLayer with descriptor instead")
void VisitGatherLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr) override;
+ void VisitGatherLayer(const armnn::IConnectableLayer* layer,
+ const armnn::GatherDescriptor& gatherDescriptor,
+ const char* name = nullptr) override;
+
ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead")
void VisitGreaterLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr) override;
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index fa43e09647..088282a18a 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -1333,17 +1333,22 @@ BOOST_AUTO_TEST_CASE(SerializeFullyConnected)
BOOST_AUTO_TEST_CASE(SerializeGather)
{
- class GatherLayerVerifier : public LayerVerifierBase
+ using GatherDescriptor = armnn::GatherDescriptor;
+ class GatherLayerVerifier : public LayerVerifierBaseWithDescriptor<GatherDescriptor>
{
public:
GatherLayerVerifier(const std::string& layerName,
const std::vector<armnn::TensorInfo>& inputInfos,
- const std::vector<armnn::TensorInfo>& outputInfos)
- : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
+ const std::vector<armnn::TensorInfo>& outputInfos,
+ const GatherDescriptor& descriptor)
+ : LayerVerifierBaseWithDescriptor<GatherDescriptor>(layerName, inputInfos, outputInfos, descriptor) {}
- void VisitGatherLayer(const armnn::IConnectableLayer* layer, const char *name) override
+ void VisitGatherLayer(const armnn::IConnectableLayer* layer,
+ const GatherDescriptor& descriptor,
+ const char *name) override
{
VerifyNameAndConnections(layer, name);
+ BOOST_CHECK(descriptor.m_Axis == m_Descriptor.m_Axis);
}
void VisitConstantLayer(const armnn::IConnectableLayer*,
@@ -1355,6 +1360,8 @@ BOOST_AUTO_TEST_CASE(SerializeGather)
armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::QAsymmU8);
armnn::TensorInfo outputInfo({ 3 }, armnn::DataType::QAsymmU8);
const armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
+ GatherDescriptor descriptor;
+ descriptor.m_Axis = 1;
paramsInfo.SetQuantizationScale(1.0f);
paramsInfo.SetQuantizationOffset(0);
@@ -1367,7 +1374,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather)
armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0);
armnn::IConnectableLayer *const constantLayer =
network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
- armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(layerName.c_str());
+ armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(descriptor, layerName.c_str());
armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0);
inputLayer->GetOutputSlot(0).Connect(gatherLayer->GetInputSlot(0));
@@ -1381,7 +1388,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather)
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
BOOST_CHECK(deserializedNetwork);
- GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo});
+ GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}, descriptor);
deserializedNetwork->Accept(verifier);
}