about summary refs log tree commit diff
path: root/src/backends/cl/ClTensorHandle.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/ClTensorHandle.hpp')
-rw-r--r--  src/backends/cl/ClTensorHandle.hpp  184
1 files changed, 182 insertions, 2 deletions
diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp
index 3d750f9059..42657341fd 100644
--- a/src/backends/cl/ClTensorHandle.hpp
+++ b/src/backends/cl/ClTensorHandle.hpp
@@ -1,7 +1,8 @@
//
-// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
+
#pragma once
#include <aclCommon/ArmComputeTensorHandle.hpp>
@@ -22,6 +23,7 @@
namespace armnn
{
+class ClTensorHandleDecorator;
class ClTensorHandle : public IClTensorHandle
{
@@ -122,7 +124,7 @@ public:
virtual bool Import(void* memory, MemorySource source) override
{
armnn::IgnoreUnused(memory);
- if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+ if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
{
throw MemoryImportException("ClTensorHandle::Incorrect import flag");
}
@@ -137,6 +139,8 @@ public:
return false;
}
+ virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override;
+
private:
// Only used for testing
void CopyOutTo(void* memory) const override
@@ -227,6 +231,7 @@ private:
MemorySourceFlags m_ImportFlags;
bool m_Imported;
bool m_IsImportEnabled;
+ std::vector<std::shared_ptr<ClTensorHandleDecorator>> m_Decorated;
};
class ClSubTensorHandle : public IClTensorHandle
@@ -361,4 +366,179 @@ private:
ITensorHandle* parentHandle = nullptr;
};
+/** ClTensorDecorator wraps an existing CL tensor allowing us to override the TensorInfo for it */
+class ClTensorDecorator : public arm_compute::ICLTensor
+{
+public:
+ ClTensorDecorator();
+
+ ClTensorDecorator(arm_compute::ICLTensor* original, const TensorInfo& info);
+
+ ~ClTensorDecorator() = default;
+
+ ClTensorDecorator(const ClTensorDecorator&) = delete;
+
+ ClTensorDecorator& operator=(const ClTensorDecorator&) = delete;
+
+ ClTensorDecorator(ClTensorDecorator&&) = default;
+
+ ClTensorDecorator& operator=(ClTensorDecorator&&) = default;
+
+ arm_compute::ICLTensor* parent();
+
+ void map(bool blocking = true);
+ using arm_compute::ICLTensor::map;
+
+ void unmap();
+ using arm_compute::ICLTensor::unmap;
+
+ virtual arm_compute::ITensorInfo* info() const override;
+ virtual arm_compute::ITensorInfo* info() override;
+ const cl::Buffer& cl_buffer() const override;
+ arm_compute::CLQuantization quantization() const override;
+
+protected:
+ // Inherited methods overridden:
+ uint8_t* do_map(cl::CommandQueue& q, bool blocking) override;
+ void do_unmap(cl::CommandQueue& q) override;
+
+private:
+ arm_compute::ICLTensor* m_Original;
+ mutable arm_compute::TensorInfo m_TensorInfo;
+};
+
+class ClTensorHandleDecorator : public IClTensorHandle
+{
+public:
+ ClTensorHandleDecorator(IClTensorHandle* parent, const TensorInfo& info)
+ : m_Tensor(&parent->GetTensor(), info)
+ {
+ m_OriginalHandle = parent;
+ }
+
+ arm_compute::ICLTensor& GetTensor() override { return m_Tensor; }
+ arm_compute::ICLTensor const& GetTensor() const override { return m_Tensor; }
+
+ virtual void Allocate() override {}
+ virtual void Manage() override {}
+
+ virtual const void* Map(bool blocking = true) const override
+ {
+ m_Tensor.map(blocking);
+ return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
+ }
+
+ virtual void Unmap() const override
+ {
+ m_Tensor.unmap();
+ }
+
+ virtual ITensorHandle* GetParent() const override { return nullptr; }
+
+ virtual arm_compute::DataType GetDataType() const override
+ {
+ return m_Tensor.info()->data_type();
+ }
+
+ virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
+
+ TensorShape GetStrides() const override
+ {
+ return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
+ }
+
+ TensorShape GetShape() const override
+ {
+ return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
+ }
+
+private:
+ // Only used for testing
+ void CopyOutTo(void* memory) const override
+ {
+ const_cast<ClTensorHandleDecorator*>(this)->Map(true);
+ switch(this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<float*>(memory));
+ break;
+ case arm_compute::DataType::U8:
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<uint8_t*>(memory));
+ break;
+ case arm_compute::DataType::F16:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<armnn::Half*>(memory));
+ break;
+ case arm_compute::DataType::QSYMM8:
+ case arm_compute::DataType::QSYMM8_PER_CHANNEL:
+ case arm_compute::DataType::QASYMM8_SIGNED:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<int8_t*>(memory));
+ break;
+ case arm_compute::DataType::S16:
+ case arm_compute::DataType::QSYMM16:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<int16_t*>(memory));
+ break;
+ case arm_compute::DataType::S32:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<int32_t*>(memory));
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ const_cast<ClTensorHandleDecorator*>(this)->Unmap();
+ }
+
+ // Only used for testing
+ void CopyInFrom(const void* memory) override
+ {
+ this->Map(true);
+ switch(this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::U8:
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::F16:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::QSYMM8:
+ case arm_compute::DataType::QSYMM8_PER_CHANNEL:
+ case arm_compute::DataType::QASYMM8_SIGNED:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::S16:
+ case arm_compute::DataType::QSYMM16:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::S32:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
+ this->GetTensor());
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ this->Unmap();
+ }
+
+ mutable ClTensorDecorator m_Tensor;
+ IClTensorHandle* m_OriginalHandle = nullptr;
+};
+
} // namespace armnn