From dcf041cfd67acff7ebd524008050b4e1a435c0e5 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Mon, 20 Jan 2020 17:18:18 +0000 Subject: IVGCVSW-4331 Calling RemoveDebugLayers can break connections * Changed RemoveDebugLayers to move all connections from its OutputSlot. Signed-off-by: Mike Kelly Change-Id: I3c649e3f660804ca48f3c2af993a5af6a7ed4d4a --- src/armnn/DynamicQuantizationVisitor.cpp | 9 ++-- src/armnn/test/QuantizerTest.cpp | 80 ++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 4 deletions(-) diff --git a/src/armnn/DynamicQuantizationVisitor.cpp b/src/armnn/DynamicQuantizationVisitor.cpp index ba87c6d335..4b1dce0b6f 100644 --- a/src/armnn/DynamicQuantizationVisitor.cpp +++ b/src/armnn/DynamicQuantizationVisitor.cpp @@ -63,13 +63,14 @@ void DynamicQuantizationVisitor::RemoveDebugLayers() for (DebugLayer* debugLayer : m_DebugLayers) { OutputSlot& proceedingOutputSlot = *debugLayer->GetInputSlot(0).GetConnectedOutputSlot(); - InputSlot& succeedingInputSlot = *debugLayer->GetOutputSlot(0).GetConnection(0); proceedingOutputSlot.Disconnect(debugLayer->GetInputSlot(0)); - debugLayer->GetOutputSlot(0).Disconnect(succeedingInputSlot); + for (InputSlot* succeedingInputSlot : debugLayer->GetOutputSlot(0).GetConnections()) + { + debugLayer->GetOutputSlot(0).Disconnect(*succeedingInputSlot); + proceedingOutputSlot.Connect(*succeedingInputSlot); + } m_Graph.EraseLayer(debugLayer); - - proceedingOutputSlot.Connect(succeedingInputSlot); } m_DebugLayers.clear(); } diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp index 900aa1813e..52beb630f9 100644 --- a/src/armnn/test/QuantizerTest.cpp +++ b/src/armnn/test/QuantizerTest.cpp @@ -2708,5 +2708,85 @@ BOOST_AUTO_TEST_CASE(PreserveTypeQsymm16) PreserveTypeTestImpl(DataType::QSymmS16); } +BOOST_AUTO_TEST_CASE(TestConnectionPreservationAfterDynamicQuant) +{ + class TestConnectionPreservation : public LayerVisitorBase + { + public: + TestConnectionPreservation(const Graph& graph) + : LayerVisitorBase() + , m_Graph(graph) + {} + + void VisitAdditionLayer(const IConnectableLayer* layer, const char*) override + { + CheckLayerName(layer->GetInputSlot(0).GetConnection()->GetOwningLayerGuid(), "reLU1"); + CheckLayerName(layer->GetInputSlot(1).GetConnection()->GetOwningLayerGuid(), "reLU2"); + } + + void CheckLayerName(LayerGuid guid, std::string expectedName) + { + bool guidFound = false; + for (Layer* layer : m_Graph) + { + if (layer->GetGuid() == guid) + { + BOOST_CHECK_EQUAL(layer->GetName(), expectedName.c_str()); + guidFound = true; + break; + } + } + if (!guidFound) + { + BOOST_FAIL("No layer matching the GUID was found"); + } + } + + private: + Graph m_Graph; + }; + + INetworkPtr network = INetwork::Create(); + + IConnectableLayer* inputLayer = network->AddInputLayer(0,"inputLayer1"); + armnn::ActivationDescriptor ReLUDesc; + ReLUDesc.m_Function = ActivationFunction::ReLu; + + IConnectableLayer* reLULayer1 = network->AddActivationLayer(ReLUDesc, "reLU1"); + IConnectableLayer* reLULayer2 = network->AddActivationLayer(ReLUDesc, "reLU2"); + IConnectableLayer* addLayer1 = network->AddAdditionLayer("addLayer1"); + IConnectableLayer* outputLayer = network->AddOutputLayer(0,"outPutLayer1"); + + inputLayer->GetOutputSlot(0).Connect(reLULayer1->GetInputSlot(0)); + reLULayer1->GetOutputSlot(0).Connect(reLULayer2->GetInputSlot(0)); + reLULayer1->GetOutputSlot(0).Connect(addLayer1->GetInputSlot(0)); + reLULayer2->GetOutputSlot(0).Connect(addLayer1->GetInputSlot(1)); + addLayer1->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({1, 2, 2, 1}), DataType::Float32)); + reLULayer1->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({1, 2, 2, 1}), DataType::Float32)); + reLULayer2->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({1, 2, 2, 1}), DataType::Float32)); + addLayer1->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({1, 2, 2, 1}), DataType::Float32)); + + TestConnectionPreservation visitor1(boost::polymorphic_downcast(network.get())->GetGraph()); + VisitLayersTopologically(network.get(), visitor1); + + armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get()); + + armnn::TensorInfo tensorInfo = GetInputTensorInfo(boost::polymorphic_downcast(network.get())); + + std::vector inputData({0, 2, 0, 4}); + armnn::ConstTensor inputTensor(tensorInfo, inputData.data()); + + InputTensors inputTensors; + inputTensors.push_back(std::make_pair(0, inputTensor)); + quantizer->Refine(inputTensors); + + INetworkPtr quantNetwork = quantizer->ExportNetwork(); + + TestConnectionPreservation visitor2(boost::polymorphic_downcast(quantNetwork.get())->GetGraph()); + VisitLayersTopologically(quantNetwork.get(), visitor2); +} + BOOST_AUTO_TEST_SUITE_END() } // namespace armnn -- cgit v1.2.1