diff options
Diffstat (limited to 'src/backends/cl/ClTensorHandle.hpp')
-rw-r--r-- | src/backends/cl/ClTensorHandle.hpp | 114 |
1 files changed, 108 insertions, 6 deletions
diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp index 0f1f583bbe..f791ee8fc9 100644 --- a/src/backends/cl/ClTensorHandle.hpp +++ b/src/backends/cl/ClTensorHandle.hpp @@ -7,6 +7,8 @@ #include <backendsCommon/OutputHandler.hpp> #include <aclCommon/ArmComputeTensorUtils.hpp> +#include <Half.hpp> + #include <arm_compute/runtime/CL/CLTensor.h> #include <arm_compute/runtime/CL/CLSubTensor.h> #include <arm_compute/runtime/CL/CLMemoryGroup.h> @@ -59,8 +61,6 @@ public: } virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); } - virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL; } - virtual ITensorHandle* GetParent() const override { return nullptr; } virtual arm_compute::DataType GetDataType() const override @@ -82,7 +82,60 @@ public: { return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); } + private: + // Only used for testing + void CopyOutTo(void* memory) const override + { + const_cast<armnn::ClTensorHandle*>(this)->Map(true); + switch(this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<float*>(memory)); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<uint8_t*>(memory)); + break; + case arm_compute::DataType::F16: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<armnn::Half*>(memory)); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + const_cast<armnn::ClTensorHandle*>(this)->Unmap(); + } + + // Only used for testing + void CopyInFrom(const void* memory) override + { + this->Map(true); + switch(this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::F16: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory), + this->GetTensor()); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + this->Unmap(); + } + arm_compute::CLTensor m_Tensor; std::shared_ptr<arm_compute::CLMemoryGroup> m_MemoryGroup; }; @@ -111,8 +164,6 @@ public: } virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); } - virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL; } - virtual ITensorHandle* GetParent() const override { return parentHandle; } virtual arm_compute::DataType GetDataType() const override @@ -133,9 +184,60 @@ public: } private: + // Only used for testing + void CopyOutTo(void* memory) const override + { + const_cast<ClSubTensorHandle*>(this)->Map(true); + switch(this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<float*>(memory)); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<uint8_t*>(memory)); + break; + case arm_compute::DataType::F16: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<armnn::Half*>(memory)); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + const_cast<ClSubTensorHandle*>(this)->Unmap(); + } + + // Only used for testing + void CopyInFrom(const void* memory) override + { + this->Map(true); + switch(this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::F16: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory), + this->GetTensor()); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + this->Unmap(); + } + mutable arm_compute::CLSubTensor m_Tensor; ITensorHandle* parentHandle = nullptr; - }; -} +} // namespace armnn
\ No newline at end of file |