diff options
author | Mike Kelly <mike.kelly@arm.com> | 2023-07-07 15:43:06 +0100 |
---|---|---|
committer | Mike Kelly <mike.kelly@arm.com> | 2023-07-14 00:00:53 +0100 |
commit | 4cc341cf8b5a6e6bb0543504cbbfde6fa11a2cdb (patch) | |
tree | 7cac128e9ec6f2fd27f1afdb55f44b870f39e0b3 /src/backends/cl/ClTensorHandle.hpp | |
parent | 6963b33221c23af4a8eff19ff4a5773230b0befd (diff) | |
download | armnn-4cc341cf8b5a6e6bb0543504cbbfde6fa11a2cdb.tar.gz |
IVGCVSW-7830 Add backend optimizations to remove Reshapes where possible
* Added optimization to remove reshapes for Neon and Ref Backends
by using overridden TensorInfos
* Added ability to delete Subgraphs during Optimization
* Fixed naming error in NeonEndToEndTests and CLEndToEndTests
* Added LayerNameAndTypeCheck for testing.
* Fixed error where layers were not marked as altered when removed in
CLBackend
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I1ac25cd4ec9821470d961831ae2c8d24882276cc
Diffstat (limited to 'src/backends/cl/ClTensorHandle.hpp')
-rw-r--r-- | src/backends/cl/ClTensorHandle.hpp | 184 |
1 file changed, 182 insertions, 2 deletions
diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp index 3d750f9059..42657341fd 100644 --- a/src/backends/cl/ClTensorHandle.hpp +++ b/src/backends/cl/ClTensorHandle.hpp @@ -1,7 +1,8 @@ // -// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // + #pragma once #include <aclCommon/ArmComputeTensorHandle.hpp> @@ -22,6 +23,7 @@ namespace armnn { +class ClTensorHandleDecorator; class ClTensorHandle : public IClTensorHandle { @@ -122,7 +124,7 @@ public: virtual bool Import(void* memory, MemorySource source) override { armnn::IgnoreUnused(memory); - if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) + if (m_ImportFlags& static_cast<MemorySourceFlags>(source)) { throw MemoryImportException("ClTensorHandle::Incorrect import flag"); } @@ -137,6 +139,8 @@ public: return false; } + virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override; + private: // Only used for testing void CopyOutTo(void* memory) const override @@ -227,6 +231,7 @@ private: MemorySourceFlags m_ImportFlags; bool m_Imported; bool m_IsImportEnabled; + std::vector<std::shared_ptr<ClTensorHandleDecorator>> m_Decorated; }; class ClSubTensorHandle : public IClTensorHandle @@ -361,4 +366,179 @@ private: ITensorHandle* parentHandle = nullptr; }; +/** ClTensorDecorator wraps an existing CL tensor allowing us to override the TensorInfo for it */ +class ClTensorDecorator : public arm_compute::ICLTensor +{ +public: + ClTensorDecorator(); + + ClTensorDecorator(arm_compute::ICLTensor* original, const TensorInfo& info); + + ~ClTensorDecorator() = default; + + ClTensorDecorator(const ClTensorDecorator&) = delete; + + ClTensorDecorator& operator=(const ClTensorDecorator&) = delete; + + ClTensorDecorator(ClTensorDecorator&&) = default; + + ClTensorDecorator& operator=(ClTensorDecorator&&) = default; + + 
arm_compute::ICLTensor* parent(); + + void map(bool blocking = true); + using arm_compute::ICLTensor::map; + + void unmap(); + using arm_compute::ICLTensor::unmap; + + virtual arm_compute::ITensorInfo* info() const override; + virtual arm_compute::ITensorInfo* info() override; + const cl::Buffer& cl_buffer() const override; + arm_compute::CLQuantization quantization() const override; + +protected: + // Inherited methods overridden: + uint8_t* do_map(cl::CommandQueue& q, bool blocking) override; + void do_unmap(cl::CommandQueue& q) override; + +private: + arm_compute::ICLTensor* m_Original; + mutable arm_compute::TensorInfo m_TensorInfo; +}; + +class ClTensorHandleDecorator : public IClTensorHandle +{ +public: + ClTensorHandleDecorator(IClTensorHandle* parent, const TensorInfo& info) + : m_Tensor(&parent->GetTensor(), info) + { + m_OriginalHandle = parent; + } + + arm_compute::ICLTensor& GetTensor() override { return m_Tensor; } + arm_compute::ICLTensor const& GetTensor() const override { return m_Tensor; } + + virtual void Allocate() override {} + virtual void Manage() override {} + + virtual const void* Map(bool blocking = true) const override + { + m_Tensor.map(blocking); + return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); + } + + virtual void Unmap() const override + { + m_Tensor.unmap(); + } + + virtual ITensorHandle* GetParent() const override { return nullptr; } + + virtual arm_compute::DataType GetDataType() const override + { + return m_Tensor.info()->data_type(); + } + + virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {} + + TensorShape GetStrides() const override + { + return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); + } + + TensorShape GetShape() const override + { + return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); + } + +private: + // Only used for testing + void CopyOutTo(void* memory) const override + { + 
const_cast<ClTensorHandleDecorator*>(this)->Map(true); + switch(this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<float*>(memory)); + break; + case arm_compute::DataType::U8: + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<uint8_t*>(memory)); + break; + case arm_compute::DataType::F16: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<armnn::Half*>(memory)); + break; + case arm_compute::DataType::QSYMM8: + case arm_compute::DataType::QSYMM8_PER_CHANNEL: + case arm_compute::DataType::QASYMM8_SIGNED: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<int8_t*>(memory)); + break; + case arm_compute::DataType::S16: + case arm_compute::DataType::QSYMM16: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<int16_t*>(memory)); + break; + case arm_compute::DataType::S32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<int32_t*>(memory)); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + const_cast<ClTensorHandleDecorator*>(this)->Unmap(); + } + + // Only used for testing + void CopyInFrom(const void* memory) override + { + this->Map(true); + switch(this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::U8: + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::F16: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::QSYMM8: + case arm_compute::DataType::QSYMM8_PER_CHANNEL: 
+ case arm_compute::DataType::QASYMM8_SIGNED: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::S16: + case arm_compute::DataType::QSYMM16: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::S32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), + this->GetTensor()); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + this->Unmap(); + } + + mutable ClTensorDecorator m_Tensor; + IClTensorHandle* m_OriginalHandle = nullptr; +}; + } // namespace armnn |