aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/test/RefOptimizedNetworkTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/test/RefOptimizedNetworkTests.cpp')
-rw-r--r--src/backends/reference/test/RefOptimizedNetworkTests.cpp19
1 files changed, 13 insertions, 6 deletions
diff --git a/src/backends/reference/test/RefOptimizedNetworkTests.cpp b/src/backends/reference/test/RefOptimizedNetworkTests.cpp
index 16ff202f70..086c1e471a 100644
--- a/src/backends/reference/test/RefOptimizedNetworkTests.cpp
+++ b/src/backends/reference/test/RefOptimizedNetworkTests.cpp
@@ -71,12 +71,13 @@ BOOST_AUTO_TEST_CASE(OptimizeValidateCpuRefWorkloads)
std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
armnn::IOptimizedNetworkPtr optNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec());
- static_cast<armnn::OptimizedNetwork*>(optNet.get())->GetGraph().AllocateDynamicBuffers();
+ armnn::Graph& graph = GetGraphForTesting(optNet.get());
+ graph.AllocateDynamicBuffers();
BOOST_CHECK(optNet);
// Validates workloads.
armnn::RefWorkloadFactory fact;
- for (auto&& layer : static_cast<armnn::OptimizedNetwork*>(optNet.get())->GetGraph())
+ for (auto&& layer : graph)
{
BOOST_CHECK_NO_THROW(layer->CreateWorkload(fact));
}
@@ -109,7 +110,10 @@ BOOST_AUTO_TEST_CASE(OptimizeValidateWorkloadsCpuRefPermuteLayer)
// optimize the network
armnn::IOptimizedNetworkPtr optNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec());
- for (auto&& layer : static_cast<armnn::OptimizedNetwork*>(optNet.get())->GetGraph())
+ armnn::Graph& graph = GetGraphForTesting(optNet.get());
+ graph.AllocateDynamicBuffers();
+
+ for (auto&& layer : graph)
{
BOOST_CHECK(layer->GetBackendId() == armnn::Compute::CpuRef);
}
@@ -141,8 +145,9 @@ BOOST_AUTO_TEST_CASE(OptimizeValidateWorkloadsCpuRefMeanLayer)
// optimize the network
armnn::IOptimizedNetworkPtr optNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec());
-
- for (auto&& layer : static_cast<armnn::OptimizedNetwork*>(optNet.get())->GetGraph())
+ armnn::Graph& graph = GetGraphForTesting(optNet.get());
+ graph.AllocateDynamicBuffers();
+ for (auto&& layer : graph)
{
BOOST_CHECK(layer->GetBackendId() == armnn::Compute::CpuRef);
}
@@ -183,7 +188,9 @@ BOOST_AUTO_TEST_CASE(DebugTestOnCpuRef)
armnn::IOptimizedNetworkPtr optimizedNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec(),
optimizerOptions);
- const armnn::Graph& graph = static_cast<armnn::OptimizedNetwork*>(optimizedNet.get())->GetGraph();
+ armnn::Graph& graph = GetGraphForTesting(optimizedNet.get());
+ graph.AllocateDynamicBuffers();
+
// Tests that all layers are present in the graph.
BOOST_TEST(graph.GetNumLayers() == 5);