aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/armnn/DynamicQuantizationVisitor.cpp9
-rw-r--r--src/armnn/test/QuantizerTest.cpp80
2 files changed, 85 insertions, 4 deletions
diff --git a/src/armnn/DynamicQuantizationVisitor.cpp b/src/armnn/DynamicQuantizationVisitor.cpp
index ba87c6d335..4b1dce0b6f 100644
--- a/src/armnn/DynamicQuantizationVisitor.cpp
+++ b/src/armnn/DynamicQuantizationVisitor.cpp
@@ -63,13 +63,14 @@ void DynamicQuantizationVisitor::RemoveDebugLayers()
for (DebugLayer* debugLayer : m_DebugLayers)
{
OutputSlot& proceedingOutputSlot = *debugLayer->GetInputSlot(0).GetConnectedOutputSlot();
- InputSlot& succeedingInputSlot = *debugLayer->GetOutputSlot(0).GetConnection(0);
proceedingOutputSlot.Disconnect(debugLayer->GetInputSlot(0));
- debugLayer->GetOutputSlot(0).Disconnect(succeedingInputSlot);
+ for (InputSlot* succeedingInputSlot : debugLayer->GetOutputSlot(0).GetConnections())
+ {
+ debugLayer->GetOutputSlot(0).Disconnect(*succeedingInputSlot);
+ proceedingOutputSlot.Connect(*succeedingInputSlot);
+ }
m_Graph.EraseLayer(debugLayer);
-
- proceedingOutputSlot.Connect(succeedingInputSlot);
}
m_DebugLayers.clear();
}
diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp
index 900aa1813e..52beb630f9 100644
--- a/src/armnn/test/QuantizerTest.cpp
+++ b/src/armnn/test/QuantizerTest.cpp
@@ -2708,5 +2708,85 @@ BOOST_AUTO_TEST_CASE(PreserveTypeQsymm16)
PreserveTypeTestImpl(DataType::QSymmS16);
}
+BOOST_AUTO_TEST_CASE(TestConnectionPreservationAfterDynamicQuant)
+{
+ class TestConnectionPreservation : public LayerVisitorBase<VisitorNoThrowPolicy>
+ {
+ public:
+ TestConnectionPreservation(const Graph& graph)
+ : LayerVisitorBase<VisitorNoThrowPolicy>()
+ , m_Graph(graph)
+ {}
+
+ void VisitAdditionLayer(const IConnectableLayer* layer, const char*) override
+ {
+ CheckLayerName(layer->GetInputSlot(0).GetConnection()->GetOwningLayerGuid(), "reLU1");
+ CheckLayerName(layer->GetInputSlot(1).GetConnection()->GetOwningLayerGuid(), "reLU2");
+ }
+
+ void CheckLayerName(LayerGuid guid, std::string expectedName)
+ {
+ bool guidFound = false;
+ for (Layer* layer : m_Graph)
+ {
+ if (layer->GetGuid() == guid)
+ {
+ BOOST_CHECK_EQUAL(layer->GetName(), expectedName.c_str());
+ guidFound = true;
+ break;
+ }
+ }
+ if (!guidFound)
+ {
+ BOOST_FAIL("No layer matching the GUID was found");
+ }
+ }
+
+ private:
+ Graph m_Graph;
+ };
+
+ INetworkPtr network = INetwork::Create();
+
+ IConnectableLayer* inputLayer = network->AddInputLayer(0,"inputLayer1");
+ armnn::ActivationDescriptor ReLUDesc;
+ ReLUDesc.m_Function = ActivationFunction::ReLu;
+
+ IConnectableLayer* reLULayer1 = network->AddActivationLayer(ReLUDesc, "reLU1");
+ IConnectableLayer* reLULayer2 = network->AddActivationLayer(ReLUDesc, "reLU2");
+ IConnectableLayer* addLayer1 = network->AddAdditionLayer("addLayer1");
+ IConnectableLayer* outputLayer = network->AddOutputLayer(0,"outPutLayer1");
+
+ inputLayer->GetOutputSlot(0).Connect(reLULayer1->GetInputSlot(0));
+ reLULayer1->GetOutputSlot(0).Connect(reLULayer2->GetInputSlot(0));
+ reLULayer1->GetOutputSlot(0).Connect(addLayer1->GetInputSlot(0));
+ reLULayer2->GetOutputSlot(0).Connect(addLayer1->GetInputSlot(1));
+ addLayer1->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+ inputLayer->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({1, 2, 2, 1}), DataType::Float32));
+ reLULayer1->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({1, 2, 2, 1}), DataType::Float32));
+ reLULayer2->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({1, 2, 2, 1}), DataType::Float32));
+ addLayer1->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({1, 2, 2, 1}), DataType::Float32));
+
+ TestConnectionPreservation visitor1(boost::polymorphic_downcast<const Network*>(network.get())->GetGraph());
+ VisitLayersTopologically(network.get(), visitor1);
+
+ armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get());
+
+ armnn::TensorInfo tensorInfo = GetInputTensorInfo(boost::polymorphic_downcast<const Network*>(network.get()));
+
+ std::vector<float> inputData({0, 2, 0, 4});
+ armnn::ConstTensor inputTensor(tensorInfo, inputData.data());
+
+ InputTensors inputTensors;
+ inputTensors.push_back(std::make_pair(0, inputTensor));
+ quantizer->Refine(inputTensors);
+
+ INetworkPtr quantNetwork = quantizer->ExportNetwork();
+
+ TestConnectionPreservation visitor2(boost::polymorphic_downcast<const Network*>(quantNetwork.get())->GetGraph());
+ VisitLayersTopologically(quantNetwork.get(), visitor2);
+}
+
BOOST_AUTO_TEST_SUITE_END()
} // namespace armnn