aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/LoadedNetwork.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/LoadedNetwork.hpp')
-rw-r--r--src/armnn/LoadedNetwork.hpp26
1 files changed, 13 insertions, 13 deletions
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
index d6af11e779..79a0b267e9 100644
--- a/src/armnn/LoadedNetwork.hpp
+++ b/src/armnn/LoadedNetwork.hpp
@@ -8,6 +8,9 @@
#include "armnn/Types.hpp"
#include "Network.hpp"
#include "LayerFwd.hpp"
+#include "backends/RefWorkloadFactory.hpp"
+#include "backends/NeonWorkloadFactory.hpp"
+#include "backends/ClWorkloadFactory.hpp"
#include "backends/Workload.hpp"
#include "backends/WorkloadFactory.hpp"
@@ -21,38 +24,35 @@ namespace cl
namespace armnn
{
-struct WorkloadFactories;
-
class LoadedNetwork
{
public:
TensorInfo GetInputTensorInfo(LayerBindingId layerId) const;
TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const;
- Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors,
- const WorkloadFactories& workloadFactories);
+ Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors);
static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<OptimizedNetwork> net,
- const WorkloadFactories& workloadFactories);
+ bool useCpuRefAsFallback);
private:
- LoadedNetwork(std::unique_ptr<OptimizedNetwork> net, const WorkloadFactories& workloadFactories);
+ LoadedNetwork(std::unique_ptr<OptimizedNetwork> net, bool useCpuRefAsFallback);
- void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo,
- const WorkloadFactories& workloadFactories);
+ void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);
- void EnqueueOutput(const BindableLayer& layer, ITensorHandle* tensorHandle,
- const TensorInfo& tensorInfo, const WorkloadFactories& workloadFactories);
+ void EnqueueOutput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);
bool Execute();
void TidyWorkloadQueue(size_t numInputs, size_t numOutputs);
- const std::shared_ptr<IWorkloadFactory> GetWorkloadFactory(const Layer& layer,
- const WorkloadFactories& workloadFactories) const;
+ const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const;
- std::unique_ptr<OptimizedNetwork> m_OptimizedNetwork;
+ RefWorkloadFactory m_CpuRef;
+ NeonWorkloadFactory m_CpuAcc;
+ ClWorkloadFactory m_GpuAcc;
+ std::unique_ptr<OptimizedNetwork> m_OptimizedNetwork;
std::vector< std::unique_ptr<IWorkload> > m_WorkloadQueue;
};