aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/test/SerializerTests.cpp
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2020-06-29 16:27:03 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2020-07-01 08:26:47 +0000
commit526647333571169076f5e72c9fb18c71025bf7c0 (patch)
tree6dc559a7b0fae3705172b09a88fa552926652040 /src/armnnSerializer/test/SerializerTests.cpp
parentcbd2c230b7ce5f26e2ccccf36b7ad450f6e1ad09 (diff)
downloadarmnn-526647333571169076f5e72c9fb18c71025bf7c0.tar.gz
IVGCVSW-4903 Connect axis parameter in Gather from android to ACL.
!android-nn-driver:3302 Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: Ifbc49acb5272f8a36719bb68676e44817190537d
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp21
1 files changed, 14 insertions, 7 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index fa43e09647..088282a18a 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -1333,17 +1333,22 @@ BOOST_AUTO_TEST_CASE(SerializeFullyConnected)
BOOST_AUTO_TEST_CASE(SerializeGather)
{
- class GatherLayerVerifier : public LayerVerifierBase
+ using GatherDescriptor = armnn::GatherDescriptor;
+ class GatherLayerVerifier : public LayerVerifierBaseWithDescriptor<GatherDescriptor>
{
public:
GatherLayerVerifier(const std::string& layerName,
const std::vector<armnn::TensorInfo>& inputInfos,
- const std::vector<armnn::TensorInfo>& outputInfos)
- : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
+ const std::vector<armnn::TensorInfo>& outputInfos,
+ const GatherDescriptor& descriptor)
+ : LayerVerifierBaseWithDescriptor<GatherDescriptor>(layerName, inputInfos, outputInfos, descriptor) {}
- void VisitGatherLayer(const armnn::IConnectableLayer* layer, const char *name) override
+ void VisitGatherLayer(const armnn::IConnectableLayer* layer,
+ const GatherDescriptor& descriptor,
+ const char *name) override
{
VerifyNameAndConnections(layer, name);
+ BOOST_CHECK(descriptor.m_Axis == m_Descriptor.m_Axis);
}
void VisitConstantLayer(const armnn::IConnectableLayer*,
@@ -1355,6 +1360,8 @@ BOOST_AUTO_TEST_CASE(SerializeGather)
armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::QAsymmU8);
armnn::TensorInfo outputInfo({ 3 }, armnn::DataType::QAsymmU8);
const armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
+ GatherDescriptor descriptor;
+ descriptor.m_Axis = 1;
paramsInfo.SetQuantizationScale(1.0f);
paramsInfo.SetQuantizationOffset(0);
@@ -1367,7 +1374,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather)
armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0);
armnn::IConnectableLayer *const constantLayer =
network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
- armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(layerName.c_str());
+ armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(descriptor, layerName.c_str());
armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0);
inputLayer->GetOutputSlot(0).Connect(gatherLayer->GetInputSlot(0));
@@ -1381,7 +1388,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather)
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
BOOST_CHECK(deserializedNetwork);
- GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo});
+ GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}, descriptor);
deserializedNetwork->Accept(verifier);
}