aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/test/SerializerTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp21
1 files changed, 14 insertions, 7 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index fa43e09647..088282a18a 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -1333,17 +1333,22 @@ BOOST_AUTO_TEST_CASE(SerializeFullyConnected)
BOOST_AUTO_TEST_CASE(SerializeGather)
{
- class GatherLayerVerifier : public LayerVerifierBase
+ using GatherDescriptor = armnn::GatherDescriptor;
+ class GatherLayerVerifier : public LayerVerifierBaseWithDescriptor<GatherDescriptor>
{
public:
GatherLayerVerifier(const std::string& layerName,
const std::vector<armnn::TensorInfo>& inputInfos,
- const std::vector<armnn::TensorInfo>& outputInfos)
- : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
+ const std::vector<armnn::TensorInfo>& outputInfos,
+ const GatherDescriptor& descriptor)
+ : LayerVerifierBaseWithDescriptor<GatherDescriptor>(layerName, inputInfos, outputInfos, descriptor) {}
- void VisitGatherLayer(const armnn::IConnectableLayer* layer, const char *name) override
+ void VisitGatherLayer(const armnn::IConnectableLayer* layer,
+ const GatherDescriptor& descriptor,
+ const char *name) override
{
VerifyNameAndConnections(layer, name);
+ BOOST_CHECK(descriptor.m_Axis == m_Descriptor.m_Axis);
}
void VisitConstantLayer(const armnn::IConnectableLayer*,
@@ -1355,6 +1360,8 @@ BOOST_AUTO_TEST_CASE(SerializeGather)
armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::QAsymmU8);
armnn::TensorInfo outputInfo({ 3 }, armnn::DataType::QAsymmU8);
const armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
+ GatherDescriptor descriptor;
+ descriptor.m_Axis = 1;
paramsInfo.SetQuantizationScale(1.0f);
paramsInfo.SetQuantizationOffset(0);
@@ -1367,7 +1374,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather)
armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0);
armnn::IConnectableLayer *const constantLayer =
network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
- armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(layerName.c_str());
+ armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(descriptor, layerName.c_str());
armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0);
inputLayer->GetOutputSlot(0).Connect(gatherLayer->GetInputSlot(0));
@@ -1381,7 +1388,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather)
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
BOOST_CHECK(deserializedNetwork);
- GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo});
+ GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}, descriptor);
deserializedNetwork->Accept(verifier);
}