aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/LoadedNetwork.hpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
committertelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
commit4fcda0101ec3d110c1d6d7bee5c83416b645528a (patch)
treec9a70aeb2887006160c1b3d265c27efadb7bdbae /src/armnn/LoadedNetwork.hpp
downloadarmnn-4fcda0101ec3d110c1d6d7bee5c83416b645528a.tar.gz
Release 18.02
Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6
Diffstat (limited to 'src/armnn/LoadedNetwork.hpp')
-rw-r--r--src/armnn/LoadedNetwork.hpp59
1 file changed, 59 insertions, 0 deletions
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
new file mode 100644
index 0000000000..d6af11e779
--- /dev/null
+++ b/src/armnn/LoadedNetwork.hpp
@@ -0,0 +1,59 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#pragma once
+
+#include "armnn/Tensor.hpp"
+#include "armnn/Types.hpp"
+#include "Network.hpp"
+#include "LayerFwd.hpp"
+#include "backends/Workload.hpp"
+#include "backends/WorkloadFactory.hpp"
+
+// Forward declarations of the OpenCL C++ wrapper types used by the CL backend.
+// Declaring them here (instead of including the OpenCL headers) keeps this
+// widely-included header free of the heavyweight OpenCL dependency.
+namespace cl
+{
+ class Context;
+ class CommandQueue;
+ class Device;
+}
+
+namespace armnn
+{
+
+struct WorkloadFactories;
+
+// A network that has been optimized and instantiated for execution: it owns the
+// OptimizedNetwork it was built from plus the queue of workloads created for it.
+// Instances are created via the static MakeLoadedNetwork factory (the constructor
+// is private), and run via EnqueueWorkload.
+class LoadedNetwork
+{
+public:
+ // Returns the TensorInfo of the bound input layer identified by layerId.
+ // NOTE(review): presumably throws/asserts if layerId is unknown - confirm in the .cpp.
+ TensorInfo GetInputTensorInfo(LayerBindingId layerId) const;
+ // Returns the TensorInfo of the bound output layer identified by layerId.
+ TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const;
+
+ // Runs one inference: binds the supplied input/output tensors, then executes
+ // the workload queue. Returns a Status indicating success or failure.
+ Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors,
+ const WorkloadFactories& workloadFactories);
+
+ // Factory: takes ownership of the optimized network and builds the per-layer
+ // workloads using the given factories. Sole way to construct a LoadedNetwork.
+ static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<OptimizedNetwork> net,
+ const WorkloadFactories& workloadFactories);
+
+private:
+ // Private: construction goes through MakeLoadedNetwork.
+ LoadedNetwork(std::unique_ptr<OptimizedNetwork> net, const WorkloadFactories& workloadFactories);
+
+ // Creates and enqueues the workload that feeds an input tensor into the graph.
+ void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo,
+ const WorkloadFactories& workloadFactories);
+
+ // Creates and enqueues the workload that reads an output tensor out of the graph.
+ void EnqueueOutput(const BindableLayer& layer, ITensorHandle* tensorHandle,
+ const TensorInfo& tensorInfo, const WorkloadFactories& workloadFactories);
+
+ // Executes the current workload queue; returns false on failure.
+ // NOTE(review): exact error semantics not visible from this header - see the .cpp.
+ bool Execute();
+
+ // Removes the transient input/output workloads added for a single
+ // EnqueueWorkload call, leaving only the network's own workloads queued.
+ void TidyWorkloadQueue(size_t numInputs, size_t numOutputs);
+
+ // Selects the workload factory (backend) responsible for the given layer.
+ const std::shared_ptr<IWorkloadFactory> GetWorkloadFactory(const Layer& layer,
+ const WorkloadFactories& workloadFactories) const;
+
+ // The optimized graph this network was instantiated from (owned).
+ std::unique_ptr<OptimizedNetwork> m_OptimizedNetwork;
+
+ // Workloads in execution order; drained by Execute().
+ std::vector< std::unique_ptr<IWorkload> > m_WorkloadQueue;
+};
+
+}