diff options
author | telsoa01 <telmo.soares@arm.com> | 2018-03-09 14:13:49 +0000 |
---|---|---|
committer | telsoa01 <telmo.soares@arm.com> | 2018-03-09 14:13:49 +0000 |
commit | 4fcda0101ec3d110c1d6d7bee5c83416b645528a (patch) | |
tree | c9a70aeb2887006160c1b3d265c27efadb7bdbae /src/armnn/LoadedNetwork.hpp | |
download | armnn-4fcda0101ec3d110c1d6d7bee5c83416b645528a.tar.gz |
Release 18.02
Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6
Diffstat (limited to 'src/armnn/LoadedNetwork.hpp')
-rw-r--r-- | src/armnn/LoadedNetwork.hpp | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp new file mode 100644 index 0000000000..d6af11e779 --- /dev/null +++ b/src/armnn/LoadedNetwork.hpp @@ -0,0 +1,59 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "armnn/Tensor.hpp" +#include "armnn/Types.hpp" +#include "Network.hpp" +#include "LayerFwd.hpp" +#include "backends/Workload.hpp" +#include "backends/WorkloadFactory.hpp" + +namespace cl +{ + class Context; + class CommandQueue; + class Device; +} + +namespace armnn +{ + +struct WorkloadFactories; + +class LoadedNetwork +{ +public: + TensorInfo GetInputTensorInfo(LayerBindingId layerId) const; + TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const; + + Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors, + const WorkloadFactories& workloadFactories); + + static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<OptimizedNetwork> net, + const WorkloadFactories& workloadFactories); + +private: + LoadedNetwork(std::unique_ptr<OptimizedNetwork> net, const WorkloadFactories& workloadFactories); + + void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo, + const WorkloadFactories& workloadFactories); + + void EnqueueOutput(const BindableLayer& layer, ITensorHandle* tensorHandle, + const TensorInfo& tensorInfo, const WorkloadFactories& workloadFactories); + + bool Execute(); + + void TidyWorkloadQueue(size_t numInputs, size_t numOutputs); + + const std::shared_ptr<IWorkloadFactory> GetWorkloadFactory(const Layer& layer, + const WorkloadFactories& workloadFactories) const; + + std::unique_ptr<OptimizedNetwork> m_OptimizedNetwork; + + std::vector< std::unique_ptr<IWorkload> > m_WorkloadQueue; +}; + +} |