aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/Network_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/Network_test.cpp')
-rw-r--r--src/armnn/test/Network_test.cpp58
1 files changed, 58 insertions, 0 deletions
diff --git a/src/armnn/test/Network_test.cpp b/src/armnn/test/Network_test.cpp
index 523d47b169..057caa0505 100644
--- a/src/armnn/test/Network_test.cpp
+++ b/src/armnn/test/Network_test.cpp
@@ -29,6 +29,64 @@ bool AreAllLayerInputSlotsConnected(const armnn::IConnectableLayer& layer)
BOOST_AUTO_TEST_SUITE(Network)
+BOOST_AUTO_TEST_CASE(LayerGuids)
+{
+ armnn::Network net;
+ armnn::LayerGuid inputId = net.AddInputLayer(0)->GetGuid();
+ armnn::LayerGuid addId = net.AddAdditionLayer()->GetGuid();
+ armnn::LayerGuid outputId = net.AddOutputLayer(0)->GetGuid();
+
+ BOOST_TEST(inputId != addId);
+ BOOST_TEST(addId != outputId);
+ BOOST_TEST(inputId != outputId);
+}
+
+BOOST_AUTO_TEST_CASE(SerializeToDot)
+{
+ armnn::Network net;
+
+ //define layers
+ auto input = net.AddInputLayer(0);
+ auto add = net.AddAdditionLayer();
+ auto output = net.AddOutputLayer(0);
+
+ // connect layers
+ input->GetOutputSlot(0).Connect(add->GetInputSlot(0));
+ input->GetOutputSlot(0).Connect(add->GetInputSlot(1));
+ add->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+ armnn::TensorShape shape({4});
+ armnn::TensorInfo info(shape, armnn::DataType::Float32);
+ input->GetOutputSlot(0).SetTensorInfo(info);
+ add->GetOutputSlot(0).SetTensorInfo(info);
+
+ armnn::DeviceSpec spec;
+ spec.DefaultComputeDevice = armnn::Compute::CpuAcc;
+ armnn::IOptimizedNetworkPtr optimizedNet = armnn::Optimize(net, spec);
+
+ std::ostringstream ss;
+ optimizedNet->SerializeToDot(ss);
+
+ auto inputId = input->GetGuid();
+ auto addId = add->GetGuid();
+ auto outputId = output->GetGuid();
+
+ std::stringstream expected;
+ expected <<
+ "digraph Optimized {\n"
+ " node [shape=\"record\"];\n"
+ " edge [fontsize=8 fontcolor=\"blue\" fontname=\"arial-bold\"];\n"
+ " " << inputId << " [label=\"{Input}\"];\n"
+ " " << addId << " [label=\"{Addition}\"];\n"
+ " " << outputId << " [label=\"{Output}\"];\n"
+ " " << inputId << " -> " << addId << " [label=< [4] >];\n"
+ " " << inputId << " -> " << addId << " [label=< [4] >];\n"
+ " " << addId << " -> " << outputId << " [label=< [4] >];\n"
+ "}\n";
+
+ BOOST_TEST(ss.str() == expected.str());
+}
+
BOOST_AUTO_TEST_CASE(NetworkBasic)
{
armnn::Network net;