aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/AsyncNetwork.hpp
blob: 9c525c54721d281a20477dd8406bb6f0885bf395 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/IAsyncNetwork.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include "LayerFwd.hpp"
#include "Network.hpp"
#include "Profiling.hpp"
#include "WorkingMemHandle.hpp"

#include <armnn/backends/IBackendInternal.hpp>
#include <backendsCommon/TensorHandleFactoryRegistry.hpp>
#include <backendsCommon/Workload.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
#include <ProfilingService.hpp>
#include <TimelineUtilityMethods.hpp>

#include <unordered_map>

namespace armnn
{

namespace experimental
{

class AsyncNetwork final : public IAsyncNetwork
{
public:
    using WorkloadQueue = std::vector<std::unique_ptr<IWorkload>>;

    AsyncNetwork(std::unique_ptr<IOptimizedNetwork> net,
                 const INetworkProperties &networkProperties,
                 profiling::ProfilingService &profilingService);

    ~AsyncNetwork() { FreeWorkingMemory(); }

    TensorInfo GetInputTensorInfo(LayerBindingId layerId) const override;
    TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const override;

    /// Thread safe execution of the network. Returns once execution is complete.
    /// Will block until this and any other thread using the same workingMem object completes.
    virtual Status Execute(const InputTensors& inputTensors,
                           const OutputTensors& outputTensors,
                           IWorkingMemHandle& workingMemHandle) override;

    /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
    /// overlapped Execution by calling this function from different threads.
    std::unique_ptr<IWorkingMemHandle> CreateWorkingMemHandle() override;

    /// Get the profiler used for this network
    std::shared_ptr<IProfiler> GetProfiler() const override;

    /// Register a debug callback function to be used with this network
    void RegisterDebugCallback(const DebugCallbackFunction& func) override;

private:
    void FreeWorkingMemory();

    void CollectInputTensorHandles(std::unordered_map<LayerGuid, std::vector<ITensorHandle*> >& tensorHandles,
                                   std::vector<ITensorHandle*>& inputs,
                                   const armnn::Layer* layer,
                                   const TensorHandleFactoryRegistry& registry,
                                   const bool isMemoryManaged = false);

    void CreateOutputTensorHandles(std::unordered_map<LayerGuid, std::vector<ITensorHandle*> >& tensorHandles,
                                   std::vector<ITensorHandle*>& outputs,
                                   const armnn::Layer* layer,
                                   const TensorHandleFactoryRegistry& registry,
                                   const bool isMemoryManaged = false);

    void EnqueueInput(const BindableLayer& layer, const ConstTensor& inputTensor, WorkingMemHandle& handle);

    void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle);

    using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>;

    using WorkloadFactoryWithMemoryManager =
            std::pair<IBackendInternal::IWorkloadFactoryPtr, IBackendInternal::IMemoryManagerSharedPtr>;

    using WorkloadFactoryMap = std::unordered_map<BackendId, WorkloadFactoryWithMemoryManager>;

    const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const;

    BackendPtrMap m_Backends;
    WorkloadFactoryMap m_WorkloadFactories;

    std::unique_ptr<IOptimizedNetwork> m_OptimizedNetwork;
    INetworkProperties m_NetworkProperties;
    WorkloadQueue m_WorkloadQueue;
    std::shared_ptr<IProfiler> m_Profiler;

    TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry;

    /// Profiling Service Instance
    profiling::ProfilingService& m_ProfilingService;
};

} // end experimental namespace

} // end armnn namespace