// // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "ITensorHandle.hpp" #include #include #include #include namespace armnn { // Get a TensorShape representing the strides (in bytes) for each dimension // of a tensor, assuming fully packed data with no padding TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo); // Abstract tensor handles wrapping a readable region of memory, interpreting it as tensor data. class ConstTensorHandle : public ITensorHandle { public: template const T* GetConstTensor() const { if (armnnUtils::CompatibleTypes(GetTensorInfo().GetDataType())) { return reinterpret_cast(m_Memory); } else { throw armnn::Exception("Attempting to get not compatible type tensor!"); } } const TensorInfo& GetTensorInfo() const { return m_TensorInfo; } virtual void Manage() override {} virtual ITensorHandle* GetParent() const override { return nullptr; } virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; } virtual void Unmap() const override {} TensorShape GetStrides() const override { return GetUnpaddedTensorStrides(m_TensorInfo); } TensorShape GetShape() const override { return m_TensorInfo.GetShape(); } protected: ConstTensorHandle(const TensorInfo& tensorInfo); void SetConstMemory(const void* mem) { m_Memory = mem; } private: // Only used for testing void CopyOutTo(void *) const override { ARMNN_ASSERT_MSG(false, "Unimplemented"); } void CopyInFrom(const void*) override { ARMNN_ASSERT_MSG(false, "Unimplemented"); } ConstTensorHandle(const ConstTensorHandle& other) = delete; ConstTensorHandle& operator=(const ConstTensorHandle& other) = delete; TensorInfo m_TensorInfo; const void* m_Memory; }; template<> const void* ConstTensorHandle::GetConstTensor() const; // Abstract specialization of ConstTensorHandle that allows write access to the same data. class TensorHandle : public ConstTensorHandle { public: template T* GetTensor() const { if (armnnUtils::CompatibleTypes(GetTensorInfo().GetDataType())) { return reinterpret_cast(m_MutableMemory); } else { throw armnn::Exception("Attempting to get not compatible type tensor!"); } } protected: TensorHandle(const TensorInfo& tensorInfo); void SetMemory(void* mem) { m_MutableMemory = mem; SetConstMemory(m_MutableMemory); } private: TensorHandle(const TensorHandle& other) = delete; TensorHandle& operator=(const TensorHandle& other) = delete; void* m_MutableMemory; }; template <> void* TensorHandle::GetTensor() const; // A TensorHandle that owns the wrapped memory region. class ScopedTensorHandle : public TensorHandle { public: explicit ScopedTensorHandle(const TensorInfo& tensorInfo); // Copies contents from Tensor. explicit ScopedTensorHandle(const ConstTensor& tensor); // Copies contents from ConstTensorHandle explicit ScopedTensorHandle(const ConstTensorHandle& tensorHandle); ScopedTensorHandle(const ScopedTensorHandle& other); ScopedTensorHandle& operator=(const ScopedTensorHandle& other); ~ScopedTensorHandle(); virtual void Allocate() override; private: // Only used for testing void CopyOutTo(void* memory) const override; void CopyInFrom(const void* memory) override; void CopyFrom(const ScopedTensorHandle& other); void CopyFrom(const void* srcMemory, unsigned int numBytes); }; // A TensorHandle that wraps an already allocated memory region. // // Clients must make sure the passed in memory region stays alive for the lifetime of // the PassthroughTensorHandle instance. // // Note there is no polymorphism to/from ConstPassthroughTensorHandle. class PassthroughTensorHandle : public TensorHandle { public: PassthroughTensorHandle(const TensorInfo& tensorInfo, void* mem) : TensorHandle(tensorInfo) { SetMemory(mem); } virtual void Allocate() override; }; // A ConstTensorHandle that wraps an already allocated memory region. // // This allows users to pass in const memory to a network. // Clients must make sure the passed in memory region stays alive for the lifetime of // the PassthroughTensorHandle instance. // // Note there is no polymorphism to/from PassthroughTensorHandle. class ConstPassthroughTensorHandle : public ConstTensorHandle { public: ConstPassthroughTensorHandle(const TensorInfo& tensorInfo, const void* mem) : ConstTensorHandle(tensorInfo) { SetConstMemory(mem); } virtual void Allocate() override; }; // Template specializations. template <> const void* ConstTensorHandle::GetConstTensor() const; template <> void* TensorHandle::GetTensor() const; class ManagedConstTensorHandle { public: explicit ManagedConstTensorHandle(std::shared_ptr ptr) : m_Mapped(false) , m_TensorHandle(std::move(ptr)) {}; /// RAII Managed resource Unmaps MemoryArea once out of scope const void* Map(bool blocking = true) { if (m_TensorHandle) { auto pRet = m_TensorHandle->Map(blocking); m_Mapped = true; return pRet; } else { throw armnn::Exception("Attempting to Map null TensorHandle"); } } // Delete copy constructor as it's unnecessary ManagedConstTensorHandle(const ConstTensorHandle& other) = delete; // Delete copy assignment as it's unnecessary ManagedConstTensorHandle& operator=(const ManagedConstTensorHandle& other) = delete; // Delete move assignment as it's unnecessary ManagedConstTensorHandle& operator=(ManagedConstTensorHandle&& other) noexcept = delete; ~ManagedConstTensorHandle() { // Bias tensor handles need to be initialized empty before entering scope of if statement checking if enabled if (m_TensorHandle) { Unmap(); } } void Unmap() { // Only unmap if mapped and TensorHandle exists. if (m_Mapped && m_TensorHandle) { m_TensorHandle->Unmap(); m_Mapped = false; } } const TensorInfo& GetTensorInfo() const { return m_TensorHandle->GetTensorInfo(); } bool IsMapped() const { return m_Mapped; } private: bool m_Mapped; std::shared_ptr m_TensorHandle; }; using ConstCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ConstCpuTensorHandle is deprecated, " "use ConstTensorHandle instead", "22.05") = ConstTensorHandle; using CpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("CpuTensorHandle is deprecated, " "use TensorHandle instead", "22.05") = TensorHandle; using ScopedCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ScopedCpuTensorHandle is deprecated, " "use ScopedTensorHandle instead", "22.05") = ScopedTensorHandle; using PassthroughCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("PassthroughCpuTensorHandle is deprecated, use " "PassthroughTensorHandle instead", "22.05") = PassthroughTensorHandle; using ConstPassthroughCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ConstPassthroughCpuTensorHandle is " "deprecated, use ConstPassthroughTensorHandle " "instead", "22.05") = ConstPassthroughTensorHandle; } // namespace armnn