aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaReference/test/TosaRefOptimizedNetworkTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaReference/test/TosaRefOptimizedNetworkTests.cpp')
-rw-r--r--src/backends/tosaReference/test/TosaRefOptimizedNetworkTests.cpp54
1 files changed, 54 insertions, 0 deletions
diff --git a/src/backends/tosaReference/test/TosaRefOptimizedNetworkTests.cpp b/src/backends/tosaReference/test/TosaRefOptimizedNetworkTests.cpp
new file mode 100644
index 0000000000..64b6805d2c
--- /dev/null
+++ b/src/backends/tosaReference/test/TosaRefOptimizedNetworkTests.cpp
@@ -0,0 +1,54 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <armnn/INetwork.hpp>
+
+#include <GraphUtils.hpp>
+#include <TestUtils.hpp>
+
+#include <doctest/doctest.h>
+
+TEST_SUITE("TosaReferenceOptimizedNetwork")
+{
+
+TEST_CASE("SimpleSupportedOptimizedNetwork")
+{
+ armnn::IRuntime::CreationOptions options;
+ armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
+ armnn::INetworkPtr network(armnn::INetwork::Create());
+
+ auto inputLayer1 = network->AddInputLayer(0, "input_1");
+ auto inputLayer2 = network->AddInputLayer(1, "input_2");
+ auto addLayer = network->AddAdditionLayer("add");
+ auto outputLayer = network->AddOutputLayer(2, "output");
+
+ armnn::TensorInfo tensorInfo{{4}, armnn::DataType::Float32};
+
+ inputLayer1->GetOutputSlot(0).Connect(addLayer->GetInputSlot(0));
+ inputLayer1->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+
+ inputLayer2->GetOutputSlot(0).Connect(addLayer->GetInputSlot(1));
+ inputLayer2->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+
+ addLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+ addLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+
+ std::vector<armnn::BackendId> backends = { "TosaRef" };
+
+ armnn::OptimizerOptions optimizedOptions;
+ armnn::IOptimizedNetworkPtr optNet = Optimize(*network, backends, runtime->GetDeviceSpec(), optimizedOptions);
+ CHECK(optNet);
+
+ armnn::Graph& graph = GetGraphForTesting(optNet.get());
+
+ // Check graph layer sequence to ensure that the network has been replaced with a PreCompiledLayer
+ CHECK(CheckSequence(graph.cbegin(), graph.cend(),
+ &IsLayerOfType<armnn::InputLayer>,
+ &IsLayerOfType<armnn::InputLayer>,
+ &IsLayerOfType<armnn::PreCompiledLayer>,
+ &IsLayerOfType<armnn::OutputLayer>));
+}
+
+}