aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/ClTensorHandle.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/ClTensorHandle.hpp')
-rw-r--r--src/backends/cl/ClTensorHandle.hpp114
1 files changed, 108 insertions, 6 deletions
diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp
index 0f1f583bbe..f791ee8fc9 100644
--- a/src/backends/cl/ClTensorHandle.hpp
+++ b/src/backends/cl/ClTensorHandle.hpp
@@ -7,6 +7,8 @@
#include <backendsCommon/OutputHandler.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>
+#include <Half.hpp>
+
#include <arm_compute/runtime/CL/CLTensor.h>
#include <arm_compute/runtime/CL/CLSubTensor.h>
#include <arm_compute/runtime/CL/CLMemoryGroup.h>
@@ -59,8 +61,6 @@ public:
}
virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
- virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL; }
-
virtual ITensorHandle* GetParent() const override { return nullptr; }
virtual arm_compute::DataType GetDataType() const override
@@ -82,7 +82,60 @@ public:
{
return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
}
+
private:
+ // Only used for testing
+ void CopyOutTo(void* memory) const override
+ {
+ const_cast<armnn::ClTensorHandle*>(this)->Map(true);
+ switch(this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<float*>(memory));
+ break;
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<uint8_t*>(memory));
+ break;
+ case arm_compute::DataType::F16:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<armnn::Half*>(memory));
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ const_cast<armnn::ClTensorHandle*>(this)->Unmap();
+ }
+
+ // Only used for testing
+ void CopyInFrom(const void* memory) override
+ {
+ this->Map(true);
+ switch(this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::F16:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
+ this->GetTensor());
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ this->Unmap();
+ }
+
arm_compute::CLTensor m_Tensor;
std::shared_ptr<arm_compute::CLMemoryGroup> m_MemoryGroup;
};
@@ -111,8 +164,6 @@ public:
}
virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
- virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL; }
-
virtual ITensorHandle* GetParent() const override { return parentHandle; }
virtual arm_compute::DataType GetDataType() const override
@@ -133,9 +184,60 @@ public:
}
private:
+ // Only used for testing
+ void CopyOutTo(void* memory) const override
+ {
+ const_cast<ClSubTensorHandle*>(this)->Map(true);
+ switch(this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<float*>(memory));
+ break;
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<uint8_t*>(memory));
+ break;
+ case arm_compute::DataType::F16:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<armnn::Half*>(memory));
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ const_cast<ClSubTensorHandle*>(this)->Unmap();
+ }
+
+ // Only used for testing
+ void CopyInFrom(const void* memory) override
+ {
+ this->Map(true);
+ switch(this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::F16:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
+ this->GetTensor());
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ this->Unmap();
+ }
+
mutable arm_compute::CLSubTensor m_Tensor;
ITensorHandle* parentHandle = nullptr;
-
};
-}
+} // namespace armnn \ No newline at end of file