From 526647333571169076f5e72c9fb18c71025bf7c0 Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Mon, 29 Jun 2020 16:27:03 +0100 Subject: IVGCVSW-4903 Connect axis parameter in Gather from android to ACL. !android-nn-driver:3302 Signed-off-by: Teresa Charlin Change-Id: Ifbc49acb5272f8a36719bb68676e44817190537d --- src/armnnSerializer/test/SerializerTests.cpp | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) (limited to 'src/armnnSerializer/test/SerializerTests.cpp') diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index fa43e09647..088282a18a 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -1333,17 +1333,22 @@ BOOST_AUTO_TEST_CASE(SerializeFullyConnected) BOOST_AUTO_TEST_CASE(SerializeGather) { - class GatherLayerVerifier : public LayerVerifierBase + using GatherDescriptor = armnn::GatherDescriptor; + class GatherLayerVerifier : public LayerVerifierBaseWithDescriptor { public: GatherLayerVerifier(const std::string& layerName, const std::vector& inputInfos, - const std::vector& outputInfos) - : LayerVerifierBase(layerName, inputInfos, outputInfos) {} + const std::vector& outputInfos, + const GatherDescriptor& descriptor) + : LayerVerifierBaseWithDescriptor(layerName, inputInfos, outputInfos, descriptor) {} - void VisitGatherLayer(const armnn::IConnectableLayer* layer, const char *name) override + void VisitGatherLayer(const armnn::IConnectableLayer* layer, + const GatherDescriptor& descriptor, + const char *name) override { VerifyNameAndConnections(layer, name); + BOOST_CHECK(descriptor.m_Axis == m_Descriptor.m_Axis); } void VisitConstantLayer(const armnn::IConnectableLayer*, @@ -1355,6 +1360,8 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::QAsymmU8); armnn::TensorInfo outputInfo({ 3 }, armnn::DataType::QAsymmU8); const armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32); + GatherDescriptor descriptor; + descriptor.m_Axis = 1; paramsInfo.SetQuantizationScale(1.0f); paramsInfo.SetQuantizationOffset(0); @@ -1367,7 +1374,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0); armnn::IConnectableLayer *const constantLayer = network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData)); - armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(layerName.c_str()); + armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(descriptor, layerName.c_str()); armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0); inputLayer->GetOutputSlot(0).Connect(gatherLayer->GetInputSlot(0)); @@ -1381,7 +1388,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); BOOST_CHECK(deserializedNetwork); - GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}); + GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}, descriptor); deserializedNetwork->Accept(verifier); } -- cgit v1.2.1