diff options
Diffstat (limited to 'src/armnn/test/Network_test.cpp')
-rw-r--r-- | src/armnn/test/Network_test.cpp | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/src/armnn/test/Network_test.cpp b/src/armnn/test/Network_test.cpp index 523d47b169..057caa0505 100644 --- a/src/armnn/test/Network_test.cpp +++ b/src/armnn/test/Network_test.cpp @@ -29,6 +29,64 @@ bool AreAllLayerInputSlotsConnected(const armnn::IConnectableLayer& layer) BOOST_AUTO_TEST_SUITE(Network) +BOOST_AUTO_TEST_CASE(LayerGuids) +{ + armnn::Network net; + armnn::LayerGuid inputId = net.AddInputLayer(0)->GetGuid(); + armnn::LayerGuid addId = net.AddAdditionLayer()->GetGuid(); + armnn::LayerGuid outputId = net.AddOutputLayer(0)->GetGuid(); + + BOOST_TEST(inputId != addId); + BOOST_TEST(addId != outputId); + BOOST_TEST(inputId != outputId); +} + +BOOST_AUTO_TEST_CASE(SerializeToDot) +{ + armnn::Network net; + + //define layers + auto input = net.AddInputLayer(0); + auto add = net.AddAdditionLayer(); + auto output = net.AddOutputLayer(0); + + // connect layers + input->GetOutputSlot(0).Connect(add->GetInputSlot(0)); + input->GetOutputSlot(0).Connect(add->GetInputSlot(1)); + add->GetOutputSlot(0).Connect(output->GetInputSlot(0)); + + armnn::TensorShape shape({4}); + armnn::TensorInfo info(shape, armnn::DataType::Float32); + input->GetOutputSlot(0).SetTensorInfo(info); + add->GetOutputSlot(0).SetTensorInfo(info); + + armnn::DeviceSpec spec; + spec.DefaultComputeDevice = armnn::Compute::CpuAcc; + armnn::IOptimizedNetworkPtr optimizedNet = armnn::Optimize(net, spec); + + std::ostringstream ss; + optimizedNet->SerializeToDot(ss); + + auto inputId = input->GetGuid(); + auto addId = add->GetGuid(); + auto outputId = output->GetGuid(); + + std::stringstream expected; + expected << + "digraph Optimized {\n" + " node [shape=\"record\"];\n" + " edge [fontsize=8 fontcolor=\"blue\" fontname=\"arial-bold\"];\n" + " " << inputId << " [label=\"{Input}\"];\n" + " " << addId << " [label=\"{Addition}\"];\n" + " " << outputId << " [label=\"{Output}\"];\n" + " " << inputId << " -> " << addId << " [label=< [4] >];\n" + " " << inputId << " -> " << addId << " [label=< [4] >];\n" + " " << addId << " -> " << outputId << " [label=< [4] >];\n" + "}\n"; + + BOOST_TEST(ss.str() == expected.str()); +} + BOOST_AUTO_TEST_CASE(NetworkBasic) { armnn::Network net; |