From f81edaacc38d6edc4a2dc230460120c6f83e0cda Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Mon, 4 Mar 2019 14:34:30 +0000 Subject: IVGCVSW-2691 Add Serialize/Deseralize Gather layer Change-Id: I445c4475e5abfe500b61ce8b7138c45322043c8b Signed-off-by: Matteo Martincigh --- src/armnnDeserializer/Deserializer.cpp | 3 --- src/armnnSerializer/Serializer.cpp | 8 +++--- src/armnnSerializer/Serializer.hpp | 4 +-- src/armnnSerializer/test/SerializerTests.cpp | 40 ++++++++++++++-------------- 4 files changed, 25 insertions(+), 30 deletions(-) (limited to 'src') diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 04c629698a..ac7eae7069 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -1724,16 +1724,13 @@ void Deserializer::ParseGather(GraphPtr graph, unsigned int layerIndex) CHECK_VALID_SIZE(outputs.size(), 1); auto layerName = GetLayerName(graph, layerIndex); - IConnectableLayer* layer = m_Network->AddGatherLayer(layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); - layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); RegisterInputSlots(graph, layerIndex, layer); RegisterOutputSlots(graph, layerIndex, layer); - } } // namespace armnnDeserializer diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 38e815d9e0..4e408ae807 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -331,10 +331,8 @@ void SerializerVisitor::VisitMaximumLayer(const armnn::IConnectableLayer* layer, void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name) { - auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather); - - auto flatBufferLayer = CreateGatherLayer(m_flatBufferBuilder, - fbBaseLayer); + auto fbGatherBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather); + auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer); CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherLayer); } @@ -342,7 +340,7 @@ void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, void SerializerVisitor::VisitGreaterLayer(const armnn::IConnectableLayer* layer, const char* name) { auto fbGreaterBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Greater); - auto fbGreaterLayer = serializer::CreateGreaterLayer(m_flatBufferBuilder, fbGreaterBaseLayer); + auto fbGreaterLayer = serializer::CreateGreaterLayer(m_flatBufferBuilder, fbGreaterBaseLayer); CreateAnyLayer(fbGreaterLayer.o, serializer::Layer::Layer_GreaterLayer); } diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index 1b1a3e998d..5244a3c1fd 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -19,7 +19,7 @@ namespace armnnSerializer class SerializerVisitor : public armnn::LayerVisitorBase { public: - SerializerVisitor() : m_layerId(0) {}; + SerializerVisitor() : m_layerId(0) {} ~SerializerVisitor() {} flatbuffers::FlatBufferBuilder& GetFlatBufferBuilder() @@ -94,7 +94,7 @@ public: void VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name = nullptr) override; - + void VisitGreaterLayer(const armnn::IConnectableLayer* layer, const char* name = nullptr) override; diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 62fa0c6bd8..0689114074 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -66,8 +66,8 @@ void CheckDeserializedNetworkAgainstOriginal(const armnn::INetwork& deserialized const armnn::INetwork& originalNetwork, const std::vector& inputShapes, const std::vector& outputShapes, - const std::vector& inputBindingIds={0}, - const std::vector& outputBindingIds={0}) + const std::vector& inputBindingIds = {0}, + const std::vector& outputBindingIds = {0}) { BOOST_CHECK(inputShapes.size() == inputBindingIds.size()); BOOST_CHECK(outputShapes.size() == outputBindingIds.size()); @@ -227,12 +227,12 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeConstant) armnn::ConstTensor constTensor(commonTensorInfo, constantData); // Builds up the structure of the network. - armnn::INetworkPtr net(armnn::INetwork::Create()); + armnn::INetworkPtr network(armnn::INetwork::Create()); - armnn::IConnectableLayer* input = net->AddInputLayer(0); - armnn::IConnectableLayer* constant = net->AddConstantLayer(constTensor, "constant"); - armnn::IConnectableLayer* add = net->AddAdditionLayer(); - armnn::IConnectableLayer* output = net->AddOutputLayer(0); + armnn::IConnectableLayer* input = network->AddInputLayer(0); + armnn::IConnectableLayer* constant = network->AddConstantLayer(constTensor, "constant"); + armnn::IConnectableLayer* add = network->AddAdditionLayer(); + armnn::IConnectableLayer* output = network->AddOutputLayer(0); input->GetOutputSlot(0).Connect(add->GetInputSlot(0)); constant->GetOutputSlot(0).Connect(add->GetInputSlot(1)); @@ -243,14 +243,14 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeConstant) constant->GetOutputSlot(0).SetTensorInfo(commonTensorInfo); add->GetOutputSlot(0).SetTensorInfo(commonTensorInfo); - armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*net)); + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); BOOST_CHECK(deserializedNetwork); VerifyConstantName nameChecker; deserializedNetwork->Accept(nameChecker); CheckDeserializedNetworkAgainstOriginal(*deserializedNetwork, - *net, + *network, {commonTensorInfo.GetShape()}, {commonTensorInfo.GetShape()}); } @@ -542,8 +542,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeGreater) VerifyGreaterName nameChecker; deserializedNetwork->Accept(nameChecker); - CheckDeserializedNetworkAgainstOriginal(*network, - *deserializedNetwork, + CheckDeserializedNetworkAgainstOriginal(*deserializedNetwork, + *network, {inputTensorInfo1.GetShape(), inputTensorInfo2.GetShape()}, {outputTensorInfo.GetShape()}, {0, 1}); @@ -992,8 +992,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeBatchNormalization) VerifyBatchNormalizationName nameChecker; deserializedNetwork->Accept(nameChecker); - CheckDeserializedNetworkAgainstOriginal(*network, - *deserializedNetwork, + CheckDeserializedNetworkAgainstOriginal(*deserializedNetwork, + *network, {inputInfo.GetShape()}, {outputInfo.GetShape()}); } @@ -1127,10 +1127,10 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeEqual) deserializedNetwork->Accept(nameChecker); CheckDeserializedNetworkAgainstOriginal(*deserializedNetwork, - *network, - {inputTensorInfo1.GetShape(), inputTensorInfo2.GetShape()}, - {outputTensorInfo.GetShape()}, - {0, 1}); + *network, + {inputTensorInfo1.GetShape(), inputTensorInfo2.GetShape()}, + {outputTensorInfo.GetShape()}, + {0, 1}); } BOOST_AUTO_TEST_CASE(SerializeDeserializePad) @@ -1168,9 +1168,9 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializePad) deserializedNetwork->Accept(nameChecker); CheckDeserializedNetworkAgainstOriginal(*deserializedNetwork, - *network, - {inputTensorInfo.GetShape()}, - {outputTensorInfo.GetShape()}); + *network, + {inputTensorInfo.GetShape()}, + {outputTensorInfo.GetShape()}); } BOOST_AUTO_TEST_CASE(SerializeRsqrt) -- cgit v1.2.1