diff options
Diffstat (limited to 'src/backends/neon/NeonTensorHandle.hpp')
-rw-r--r-- | src/backends/neon/NeonTensorHandle.hpp | 168 |
1 files changed, 165 insertions, 3 deletions
diff --git a/src/backends/neon/NeonTensorHandle.hpp b/src/backends/neon/NeonTensorHandle.hpp index fcae77cdaa..e5f210773d 100644 --- a/src/backends/neon/NeonTensorHandle.hpp +++ b/src/backends/neon/NeonTensorHandle.hpp @@ -1,7 +1,8 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // + #pragma once #include <BFloat16.hpp> @@ -19,9 +20,11 @@ #include <arm_compute/runtime/SubTensor.h> #include <arm_compute/core/TensorShape.h> #include <arm_compute/core/Coordinates.h> +#include "armnn/TypesUtils.hpp" namespace armnn { +class NeonTensorHandleDecorator; class NeonTensorHandle : public IAclTensorHandle { @@ -125,7 +128,7 @@ public: virtual bool Import(void* memory, MemorySource source) override { - if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) + if (m_ImportFlags& static_cast<MemorySourceFlags>(source)) { if (source == MemorySource::Malloc && m_IsImportEnabled) { @@ -181,6 +184,8 @@ public: return false; } + virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override; + private: // Only used for testing void CopyOutTo(void* memory) const override @@ -275,6 +280,7 @@ private: bool m_Imported; bool m_IsImportEnabled; const uintptr_t m_TypeAlignment; + std::vector<std::shared_ptr<NeonTensorHandleDecorator>> m_Decorated; }; class NeonSubTensorHandle : public IAclTensorHandle @@ -283,7 +289,7 @@ public: NeonSubTensorHandle(IAclTensorHandle* parent, const arm_compute::TensorShape& shape, const arm_compute::Coordinates& coords) - : m_Tensor(&parent->GetTensor(), shape, coords) + : m_Tensor(&parent->GetTensor(), shape, coords, true) { parentHandle = parent; } @@ -319,6 +325,11 @@ public: return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); } + virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo&) override + { + return nullptr; + }; + private: // Only used for testing void CopyOutTo(void* memory) const override @@ -394,4 +405,155 @@ private: ITensorHandle* parentHandle = nullptr; }; +/// NeonTensorDecorator wraps an existing Neon tensor allowing us to override the TensorInfo for it +class NeonTensorDecorator : public arm_compute::ITensor +{ +public: + NeonTensorDecorator(); + + NeonTensorDecorator(arm_compute::ITensor* original, const TensorInfo& info); + + ~NeonTensorDecorator() = default; + + NeonTensorDecorator(const NeonTensorDecorator&) = delete; + + NeonTensorDecorator& operator=(const NeonTensorDecorator&) = delete; + + NeonTensorDecorator(NeonTensorDecorator&&) = default; + + NeonTensorDecorator& operator=(NeonTensorDecorator&&) = default; + + // Inherited methods overridden: + arm_compute::ITensorInfo* info() const override; + + arm_compute::ITensorInfo* info() override; + + uint8_t* buffer() const override; + +private: + arm_compute::ITensor* m_Original; + mutable arm_compute::TensorInfo m_TensorInfo; +}; + +class NeonTensorHandleDecorator : public IAclTensorHandle +{ +public: + NeonTensorHandleDecorator(IAclTensorHandle* parent, const TensorInfo& info) + : m_Tensor(&parent->GetTensor(), info) + { + parentHandle = parent; + } + + arm_compute::ITensor& GetTensor() override { return m_Tensor; } + arm_compute::ITensor const& GetTensor() const override { return m_Tensor; } + + virtual void Allocate() override {} + virtual void Manage() override {} + + virtual ITensorHandle* GetParent() const override { return nullptr; } + + virtual arm_compute::DataType GetDataType() const override + { + return m_Tensor.info()->data_type(); + } + + virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {} + + virtual const void* Map(bool /* blocking = true */) const override + { + return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); + } + virtual void Unmap() const override {} + + TensorShape GetStrides() const override + { + return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); + } + + TensorShape GetShape() const override + { + return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); + } + + virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo&) override + { + return nullptr; + }; + +private: + // Only used for testing + void CopyOutTo(void* memory) const override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<float*>(memory)); + break; + case arm_compute::DataType::U8: + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<uint8_t*>(memory)); + break; + case arm_compute::DataType::QSYMM8: + case arm_compute::DataType::QASYMM8_SIGNED: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<int8_t*>(memory)); + break; + case arm_compute::DataType::S16: + case arm_compute::DataType::QSYMM16: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<int16_t*>(memory)); + break; + case arm_compute::DataType::S32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<int32_t*>(memory)); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + + // Only used for testing + void CopyInFrom(const void* memory) override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::U8: + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::QSYMM8: + case arm_compute::DataType::QASYMM8_SIGNED: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::S16: + case arm_compute::DataType::QSYMM16: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::S32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), + this->GetTensor()); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + + NeonTensorDecorator m_Tensor; + ITensorHandle* parentHandle = nullptr; +}; + + } // namespace armnn |