diff options
Diffstat (limited to 'src/backends/tosaReference/test/TosaRefOptimizedNetworkTests.cpp')
-rw-r--r-- | src/backends/tosaReference/test/TosaRefOptimizedNetworkTests.cpp | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/src/backends/tosaReference/test/TosaRefOptimizedNetworkTests.cpp b/src/backends/tosaReference/test/TosaRefOptimizedNetworkTests.cpp new file mode 100644 index 0000000000..64b6805d2c --- /dev/null +++ b/src/backends/tosaReference/test/TosaRefOptimizedNetworkTests.cpp @@ -0,0 +1,54 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <armnn/INetwork.hpp> + +#include <GraphUtils.hpp> +#include <TestUtils.hpp> + +#include <doctest/doctest.h> + +TEST_SUITE("TosaReferenceOptimizedNetwork") +{ + +TEST_CASE("SimpleSupportedOptimizedNetwork") +{ + armnn::IRuntime::CreationOptions options; + armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options)); + armnn::INetworkPtr network(armnn::INetwork::Create()); + + auto inputLayer1 = network->AddInputLayer(0, "input_1"); + auto inputLayer2 = network->AddInputLayer(1, "input_2"); + auto addLayer = network->AddAdditionLayer("add"); + auto outputLayer = network->AddOutputLayer(2, "output"); + + armnn::TensorInfo tensorInfo{{4}, armnn::DataType::Float32}; + + inputLayer1->GetOutputSlot(0).Connect(addLayer->GetInputSlot(0)); + inputLayer1->GetOutputSlot(0).SetTensorInfo(tensorInfo); + + inputLayer2->GetOutputSlot(0).Connect(addLayer->GetInputSlot(1)); + inputLayer2->GetOutputSlot(0).SetTensorInfo(tensorInfo); + + addLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + addLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo); + + std::vector<armnn::BackendId> backends = { "TosaRef" }; + + armnn::OptimizerOptions optimizedOptions; + armnn::IOptimizedNetworkPtr optNet = Optimize(*network, backends, runtime->GetDeviceSpec(), optimizedOptions); + CHECK(optNet); + + armnn::Graph& graph = GetGraphForTesting(optNet.get()); + + // Check graph layer sequence to ensure that the network has been replaced with a PreCompiledLayer + CHECK(CheckSequence(graph.cbegin(), graph.cend(), + &IsLayerOfType<armnn::InputLayer>, + &IsLayerOfType<armnn::InputLayer>, + &IsLayerOfType<armnn::PreCompiledLayer>, + &IsLayerOfType<armnn::OutputLayer>)); +} + +} |