diff options
Diffstat (limited to 'src/backends/reference/RefTensorHandle.hpp')
-rw-r--r-- | src/backends/reference/RefTensorHandle.hpp | 88 |
1 files changed, 85 insertions, 3 deletions
diff --git a/src/backends/reference/RefTensorHandle.hpp b/src/backends/reference/RefTensorHandle.hpp index b4dedd5e77..128f623cd3 100644 --- a/src/backends/reference/RefTensorHandle.hpp +++ b/src/backends/reference/RefTensorHandle.hpp @@ -1,7 +1,8 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2019-2023 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // + #pragma once #include <armnn/backends/TensorHandle.hpp> @@ -11,14 +12,17 @@ namespace armnn { +class RefTensorHandleDecorator; // An implementation of ITensorHandle with simple "bump the pointer" memory-management behaviour class RefTensorHandle : public ITensorHandle { public: - RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager); + RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager>& memoryManager); RefTensorHandle(const TensorInfo& tensorInfo); + RefTensorHandle(const TensorInfo& tensorInfo, const RefTensorHandle& parent); + ~RefTensorHandle(); virtual void Manage() override; @@ -56,6 +60,8 @@ public: virtual bool Import(void* memory, MemorySource source) override; virtual bool CanBeImported(void* memory, MemorySource source) override; + virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override; + private: // Only used for testing void CopyOutTo(void*) const override; @@ -68,10 +74,86 @@ private: TensorInfo m_TensorInfo; - std::shared_ptr<RefMemoryManager> m_MemoryManager; + mutable std::shared_ptr<RefMemoryManager> m_MemoryManager; RefMemoryManager::Pool* m_Pool; mutable void* m_UnmanagedMemory; void* m_ImportedMemory; + std::vector<std::shared_ptr<RefTensorHandleDecorator>> m_Decorated; +}; + +class RefTensorHandleDecorator : public RefTensorHandle +{ +public: + RefTensorHandleDecorator(const TensorInfo& tensorInfo, const RefTensorHandle& parent); + + ~RefTensorHandleDecorator() = default; + + virtual void Manage() override; + + virtual void Allocate() override; + + virtual ITensorHandle* GetParent() const override + { + return nullptr; + } + + virtual const void* Map(bool /* blocking = true */) const override; + using ITensorHandle::Map; + + virtual void Unmap() const override + {} + + TensorShape GetStrides() const override + { + return GetUnpaddedTensorStrides(m_TensorInfo); + } + + TensorShape GetShape() const override + { + return m_TensorInfo.GetShape(); + } + + const TensorInfo& GetTensorInfo() const + { + return m_TensorInfo; + } + + virtual MemorySourceFlags GetImportFlags() const override; + + virtual bool Import(void* memory, MemorySource source) override; + virtual bool CanBeImported(void* memory, MemorySource source) override; + + virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override; + + /// Map the tensor data for access. Must be paired with call to Unmap(). + /// \param blocking hint to block the calling thread until all other accesses are complete. (backend dependent) + /// \return pointer to the first element of the mapped data. + void* Map(bool blocking=true) + { + return const_cast<void*>(static_cast<const ITensorHandle*>(this)->Map(blocking)); + } + + /// Unmap the tensor data that was previously mapped with call to Map(). + void Unmap() + { + return static_cast<const ITensorHandle*>(this)->Unmap(); + } + + /// Testing support to be able to verify and set tensor data content + void CopyOutTo(void* /* memory */) const override + {}; + + void CopyInFrom(const void* /* memory */) override + {}; + + /// Unimport externally allocated memory + void Unimport() override + {}; + +private: + TensorInfo m_TensorInfo; + const RefTensorHandle& m_Parent; }; } + |