diff options
Diffstat (limited to 'src/armnn/Layer.hpp')
-rw-r--r-- | src/armnn/Layer.hpp | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/src/armnn/Layer.hpp b/src/armnn/Layer.hpp index cbb1771668..1ddbc00bc7 100644 --- a/src/armnn/Layer.hpp +++ b/src/armnn/Layer.hpp @@ -6,7 +6,9 @@ #include "LayerFwd.hpp" +#include <backendsCommon/ITensorHandleFactory.hpp> #include <backendsCommon/OutputHandler.hpp> +#include <backendsCommon/TensorHandleFactoryRegistry.hpp> #include <backendsCommon/WorkloadDataCollector.hpp> #include <backendsCommon/WorkloadInfo.hpp> #include "InternalTypes.hpp" @@ -84,8 +86,15 @@ public: explicit OutputSlot(Layer& owner, OutputHandler& outputHandler) : m_OwningLayer(owner) , m_OutputHandler(outputHandler) + , m_TensorHandleFactoryId(ITensorHandleFactory::LegacyFactoryId) {} + OutputSlot(const OutputSlot&) = delete; + OutputSlot& operator=(const OutputSlot&) = delete; + + OutputSlot(OutputSlot&&) = default; + OutputSlot& operator=(OutputSlot&&) = default; + ~OutputSlot() { try @@ -147,12 +156,21 @@ public: bool operator==(const OutputSlot& other) const; + void SetTensorHandleFactory(const ITensorHandleFactory::FactoryId& id); + ITensorHandleFactory::FactoryId GetTensorHandleFactoryId() const; + + void SetMemoryStrategy(unsigned int connectionIndex, MemoryStrategy strategy); + MemoryStrategy GetMemoryStrategyForConnection(unsigned int connectionIdx) const; + private: void ValidateConnectionIndex(unsigned int index) const; Layer& m_OwningLayer; OutputHandler& m_OutputHandler; std::vector<InputSlot*> m_Connections; + + ITensorHandleFactory::FactoryId m_TensorHandleFactoryId; + std::vector<MemoryStrategy> m_MemoryStrategies; }; // InputSlot inlines that need OutputSlot declaration. @@ -248,7 +266,7 @@ public: virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph, const IWorkloadFactory& factory) const = 0; - virtual void CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory); + virtual void CreateTensorHandles(const TensorHandleFactoryRegistry& registry, const IWorkloadFactory& factory); /// Creates a dynamically-allocated copy of this layer. /// @param graph - The Graph into which this Layer is being cloned. |