aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/LoadedNetwork.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
-rw-r--r--src/armnn/LoadedNetwork.cpp32
1 files changed, 11 insertions, 21 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index ce9f76c986..4f73bda832 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -122,7 +122,7 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net)
const char* const layerName = layer->GetNameStr().length() != 0 ? layer->GetName() : "<Unnamed>";
throw InvalidArgumentException(boost::str(
boost::format("No workload created for layer (name: '%1%' type: '%2%') (compute '%3%')")
- % layerName % static_cast<int>(layer->GetType()) % layer->GetComputeDevice()
+ % layerName % static_cast<int>(layer->GetType()) % layer->GetBackendId().Get()
));
}
@@ -176,27 +176,17 @@ const IWorkloadFactory& LoadedNetwork::GetWorkloadFactory(const Layer& layer) co
{
const IWorkloadFactory* workloadFactory = nullptr;
- switch (layer.GetComputeDevice())
+ if (layer.GetBackendId() == Compute::CpuAcc)
{
- case Compute::CpuAcc:
- {
- workloadFactory = &m_CpuAcc;
- break;
- }
- case Compute::GpuAcc:
- {
- workloadFactory = &m_GpuAcc;
- break;
- }
- case Compute::CpuRef:
- {
- workloadFactory = &m_CpuRef;
- break;
- }
- default:
- {
- break;
- }
+ workloadFactory = &m_CpuAcc;
+ }
+ else if (layer.GetBackendId() == Compute::GpuAcc)
+ {
+ workloadFactory = &m_GpuAcc;
+ }
+ else if (layer.GetBackendId() == Compute::CpuRef)
+ {
+ workloadFactory = &m_CpuRef;
}
BOOST_ASSERT_MSG(workloadFactory, "No workload factory");