diff options
Diffstat (limited to 'src/armnnSerializer/test')
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index fa43e09647..088282a18a 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -1333,17 +1333,22 @@ BOOST_AUTO_TEST_CASE(SerializeFullyConnected) BOOST_AUTO_TEST_CASE(SerializeGather) { - class GatherLayerVerifier : public LayerVerifierBase + using GatherDescriptor = armnn::GatherDescriptor; + class GatherLayerVerifier : public LayerVerifierBaseWithDescriptor<GatherDescriptor> { public: GatherLayerVerifier(const std::string& layerName, const std::vector<armnn::TensorInfo>& inputInfos, - const std::vector<armnn::TensorInfo>& outputInfos) - : LayerVerifierBase(layerName, inputInfos, outputInfos) {} + const std::vector<armnn::TensorInfo>& outputInfos, + const GatherDescriptor& descriptor) + : LayerVerifierBaseWithDescriptor<GatherDescriptor>(layerName, inputInfos, outputInfos, descriptor) {} - void VisitGatherLayer(const armnn::IConnectableLayer* layer, const char *name) override + void VisitGatherLayer(const armnn::IConnectableLayer* layer, + const GatherDescriptor& descriptor, + const char *name) override { VerifyNameAndConnections(layer, name); + BOOST_CHECK(descriptor.m_Axis == m_Descriptor.m_Axis); } void VisitConstantLayer(const armnn::IConnectableLayer*, @@ -1355,6 +1360,8 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::QAsymmU8); armnn::TensorInfo outputInfo({ 3 }, armnn::DataType::QAsymmU8); const armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32); + GatherDescriptor descriptor; + descriptor.m_Axis = 1; paramsInfo.SetQuantizationScale(1.0f); paramsInfo.SetQuantizationOffset(0); @@ -1367,7 +1374,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0); armnn::IConnectableLayer *const constantLayer = network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData)); - armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(layerName.c_str()); + armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(descriptor, layerName.c_str()); armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0); inputLayer->GetOutputSlot(0).Connect(gatherLayer->GetInputSlot(0)); @@ -1381,7 +1388,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather) armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); BOOST_CHECK(deserializedNetwork); - GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}); + GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}, descriptor); deserializedNetwork->Accept(verifier); } |