// // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include #include #include #include #include #include #include #include #include #include #include namespace armnn { class ClTensorHandleDecorator; class ClTensorHandle : public IClTensorHandle { public: ClTensorHandle(const TensorInfo& tensorInfo) : m_ImportFlags(static_cast(MemorySource::Undefined)), m_Imported(false), m_IsImportEnabled(false) { armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo); } ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout, MemorySourceFlags importFlags = static_cast(MemorySource::Undefined)) : m_ImportFlags(importFlags), m_Imported(false), m_IsImportEnabled(false) { armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout); } arm_compute::CLTensor& GetTensor() override { return m_Tensor; } arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; } virtual void Allocate() override { // If we have enabled Importing, don't allocate the tensor if (m_IsImportEnabled) { throw MemoryImportException("ClTensorHandle::Attempting to allocate memory when importing"); } else { armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor); } } virtual void Manage() override { // If we have enabled Importing, don't manage the tensor if (m_IsImportEnabled) { throw MemoryImportException("ClTensorHandle::Attempting to manage memory when importing"); } else { assert(m_MemoryGroup != nullptr); m_MemoryGroup->manage(&m_Tensor); } } virtual const void* Map(bool blocking = true) const override { const_cast(&m_Tensor)->map(blocking); return static_cast(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); } virtual void Unmap() const override { const_cast(&m_Tensor)->unmap(); } virtual ITensorHandle* GetParent() const override { return nullptr; } virtual arm_compute::DataType GetDataType() const override { return m_Tensor.info()->data_type(); } virtual void SetMemoryGroup(const std::shared_ptr& memoryGroup) override { m_MemoryGroup = PolymorphicPointerDowncast(memoryGroup); } TensorShape GetStrides() const override { return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); } TensorShape GetShape() const override { return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); } void SetImportFlags(MemorySourceFlags importFlags) { m_ImportFlags = importFlags; } MemorySourceFlags GetImportFlags() const override { return m_ImportFlags; } void SetImportEnabledFlag(bool importEnabledFlag) { m_IsImportEnabled = importEnabledFlag; } virtual bool Import(void* memory, MemorySource source) override { armnn::IgnoreUnused(memory); if (m_ImportFlags& static_cast(source)) { throw MemoryImportException("ClTensorHandle::Incorrect import flag"); } m_Imported = false; return false; } virtual bool CanBeImported(void* memory, MemorySource source) override { // This TensorHandle can never import. armnn::IgnoreUnused(memory, source); return false; } virtual std::shared_ptr DecorateTensorHandle(const TensorInfo& tensorInfo) override; private: // Only used for testing void CopyOutTo(void* memory) const override { const_cast(this)->Map(true); switch(this->GetDataType()) { case arm_compute::DataType::F32: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::QSYMM8: case arm_compute::DataType::QSYMM8_PER_CHANNEL: case arm_compute::DataType::QASYMM8_SIGNED: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::F16: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::S16: case arm_compute::DataType::QSYMM16: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::S32: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; default: { throw armnn::UnimplementedException(); } } const_cast(this)->Unmap(); } // Only used for testing void CopyInFrom(const void* memory) override { this->Map(true); switch(this->GetDataType()) { case arm_compute::DataType::F32: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::F16: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::S16: case arm_compute::DataType::QSYMM8: case arm_compute::DataType::QSYMM8_PER_CHANNEL: case arm_compute::DataType::QASYMM8_SIGNED: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::QSYMM16: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::S32: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; default: { throw armnn::UnimplementedException(); } } this->Unmap(); } arm_compute::CLTensor m_Tensor; std::shared_ptr m_MemoryGroup; MemorySourceFlags m_ImportFlags; bool m_Imported; bool m_IsImportEnabled; std::vector> m_Decorated; }; class ClSubTensorHandle : public IClTensorHandle { public: ClSubTensorHandle(IClTensorHandle* parent, const arm_compute::TensorShape& shape, const arm_compute::Coordinates& coords) : m_Tensor(&parent->GetTensor(), shape, coords) { parentHandle = parent; } arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; } arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; } virtual void Allocate() override {} virtual void Manage() override {} virtual const void* Map(bool blocking = true) const override { const_cast(&m_Tensor)->map(blocking); return static_cast(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); } virtual void Unmap() const override { const_cast(&m_Tensor)->unmap(); } virtual ITensorHandle* GetParent() const override { return parentHandle; } virtual arm_compute::DataType GetDataType() const override { return m_Tensor.info()->data_type(); } virtual void SetMemoryGroup(const std::shared_ptr&) override {} TensorShape GetStrides() const override { return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); } TensorShape GetShape() const override { return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); } private: // Only used for testing void CopyOutTo(void* memory) const override { const_cast(this)->Map(true); switch(this->GetDataType()) { case arm_compute::DataType::F32: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::F16: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::QSYMM8: case arm_compute::DataType::QSYMM8_PER_CHANNEL: case arm_compute::DataType::QASYMM8_SIGNED: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::S16: case arm_compute::DataType::QSYMM16: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::S32: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; default: { throw armnn::UnimplementedException(); } } const_cast(this)->Unmap(); } // Only used for testing void CopyInFrom(const void* memory) override { this->Map(true); switch(this->GetDataType()) { case arm_compute::DataType::F32: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::F16: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::QSYMM8: case arm_compute::DataType::QSYMM8_PER_CHANNEL: case arm_compute::DataType::QASYMM8_SIGNED: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::S16: case arm_compute::DataType::QSYMM16: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::S32: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; default: { throw armnn::UnimplementedException(); } } this->Unmap(); } mutable arm_compute::CLSubTensor m_Tensor; ITensorHandle* parentHandle = nullptr; }; /** ClTensorDecorator wraps an existing CL tensor allowing us to override the TensorInfo for it */ class ClTensorDecorator : public arm_compute::ICLTensor { public: ClTensorDecorator(); ClTensorDecorator(arm_compute::ICLTensor* original, const TensorInfo& info); ~ClTensorDecorator() = default; ClTensorDecorator(const ClTensorDecorator&) = delete; ClTensorDecorator& operator=(const ClTensorDecorator&) = delete; ClTensorDecorator(ClTensorDecorator&&) = default; ClTensorDecorator& operator=(ClTensorDecorator&&) = default; arm_compute::ICLTensor* parent(); void map(bool blocking = true); using arm_compute::ICLTensor::map; void unmap(); using arm_compute::ICLTensor::unmap; virtual arm_compute::ITensorInfo* info() const override; virtual arm_compute::ITensorInfo* info() override; const cl::Buffer& cl_buffer() const override; arm_compute::CLQuantization quantization() const override; protected: // Inherited methods overridden: uint8_t* do_map(cl::CommandQueue& q, bool blocking) override; void do_unmap(cl::CommandQueue& q) override; private: arm_compute::ICLTensor* m_Original; mutable arm_compute::TensorInfo m_TensorInfo; }; class ClTensorHandleDecorator : public IClTensorHandle { public: ClTensorHandleDecorator(IClTensorHandle* parent, const TensorInfo& info) : m_Tensor(&parent->GetTensor(), info) { m_OriginalHandle = parent; } arm_compute::ICLTensor& GetTensor() override { return m_Tensor; } arm_compute::ICLTensor const& GetTensor() const override { return m_Tensor; } virtual void Allocate() override {} virtual void Manage() override {} virtual const void* Map(bool blocking = true) const override { m_Tensor.map(blocking); return static_cast(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); } virtual void Unmap() const override { m_Tensor.unmap(); } virtual ITensorHandle* GetParent() const override { return nullptr; } virtual arm_compute::DataType GetDataType() const override { return m_Tensor.info()->data_type(); } virtual void SetMemoryGroup(const std::shared_ptr&) override {} TensorShape GetStrides() const override { return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); } TensorShape GetShape() const override { return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); } private: // Only used for testing void CopyOutTo(void* memory) const override { const_cast(this)->Map(true); switch(this->GetDataType()) { case arm_compute::DataType::F32: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::F16: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::QSYMM8: case arm_compute::DataType::QSYMM8_PER_CHANNEL: case arm_compute::DataType::QASYMM8_SIGNED: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::S16: case arm_compute::DataType::QSYMM16: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; case arm_compute::DataType::S32: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; default: { throw armnn::UnimplementedException(); } } const_cast(this)->Unmap(); } // Only used for testing void CopyInFrom(const void* memory) override { this->Map(true); switch(this->GetDataType()) { case arm_compute::DataType::F32: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::F16: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::QSYMM8: case arm_compute::DataType::QSYMM8_PER_CHANNEL: case arm_compute::DataType::QASYMM8_SIGNED: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::S16: case arm_compute::DataType::QSYMM16: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; case arm_compute::DataType::S32: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; default: { throw armnn::UnimplementedException(); } } this->Unmap(); } mutable ClTensorDecorator m_Tensor; IClTensorHandle* m_OriginalHandle = nullptr; }; } // namespace armnn