aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backends/cl/ClTensorHandle.hpp57
1 files changed, 53 insertions, 4 deletions
diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp
index 0302ef5790..2798de6d0c 100644
--- a/src/backends/cl/ClTensorHandle.hpp
+++ b/src/backends/cl/ClTensorHandle.hpp
@@ -35,23 +35,43 @@ class ClTensorHandle : public IClTensorHandle
{
public:
ClTensorHandle(const TensorInfo& tensorInfo)
+ : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
+ m_Imported(false),
+ m_IsImportEnabled(false)
{
armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
}
- ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
+ ClTensorHandle(const TensorInfo& tensorInfo,
+ DataLayout dataLayout,
+ MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined))
+ : m_ImportFlags(importFlags),
+ m_Imported(false),
+ m_IsImportEnabled(false)
{
armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
}
arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
- virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
+ virtual void Allocate() override
+ {
+ // If we have enabled Importing, don't allocate the tensor
+ if (!m_IsImportEnabled)
+ {
+ armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
+ }
+
+ }
virtual void Manage() override
{
- assert(m_MemoryGroup != nullptr);
- m_MemoryGroup->manage(&m_Tensor);
+ // If we have enabled Importing, don't manage the tensor
+ if (!m_IsImportEnabled)
+ {
+ assert(m_MemoryGroup != nullptr);
+ m_MemoryGroup->manage(&m_Tensor);
+ }
}
virtual const void* Map(bool blocking = true) const override
@@ -84,6 +104,32 @@ public:
return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
}
+ void SetImportFlags(MemorySourceFlags importFlags)
+ {
+ m_ImportFlags = importFlags;
+ }
+
+ MemorySourceFlags GetImportFlags() const override
+ {
+ return m_ImportFlags;
+ }
+
+ void SetImportEnabledFlag(bool importEnabledFlag)
+ {
+ m_IsImportEnabled = importEnabledFlag;
+ }
+
+ virtual bool Import(void* memory, MemorySource source) override
+ {
+ armnn::IgnoreUnused(memory);
+ if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+ {
+ throw MemoryImportException("ClTensorHandle::Incorrect import flag");
+ }
+ m_Imported = false;
+ return false;
+ }
+
private:
// Only used for testing
void CopyOutTo(void* memory) const override
@@ -169,6 +215,9 @@ private:
arm_compute::CLTensor m_Tensor;
std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
+ MemorySourceFlags m_ImportFlags;
+ bool m_Imported;
+ bool m_IsImportEnabled;
};
class ClSubTensorHandle : public IClTensorHandle