aboutsummaryrefslogtreecommitdiff
path: root/src/backends
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends')
-rw-r--r--src/backends/backendsCommon/CpuTensorHandle.hpp67
-rw-r--r--src/backends/backendsCommon/test/DefaultAsyncExecuteTest.cpp1
-rw-r--r--src/backends/reference/test/RefTensorHandleTests.cpp33
3 files changed, 100 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/CpuTensorHandle.hpp b/src/backends/backendsCommon/CpuTensorHandle.hpp
index a300fe09c2..fdd2439b41 100644
--- a/src/backends/backendsCommon/CpuTensorHandle.hpp
+++ b/src/backends/backendsCommon/CpuTensorHandle.hpp
@@ -175,4 +175,71 @@ const void* ConstCpuTensorHandle::GetConstTensor() const;
template <>
void* CpuTensorHandle::GetTensor() const;
+class ManagedConstTensorHandle
+{
+
+public:
+ explicit ManagedConstTensorHandle(std::shared_ptr<ConstCpuTensorHandle> ptr)
+ : m_Mapped(false)
+ , m_TensorHandle(std::move(ptr)) {};
+
+ /// RAII Managed resource Unmaps MemoryArea once out of scope
+ const void* Map(bool blocking = true)
+ {
+ if (m_TensorHandle)
+ {
+ auto pRet = m_TensorHandle->Map(blocking);
+ m_Mapped = true;
+ return pRet;
+ }
+ else
+ {
+ throw armnn::Exception("Attempting to Map null TensorHandle");
+ }
+
+ }
+
+ // Delete copy constructor as it's unnecessary
+ ManagedConstTensorHandle(const ConstCpuTensorHandle& other) = delete;
+
+ // Delete copy assignment as it's unnecessary
+ ManagedConstTensorHandle& operator=(const ManagedConstTensorHandle& other) = delete;
+
+ // Delete move assignment as it's unnecessary
+ ManagedConstTensorHandle& operator=(ManagedConstTensorHandle&& other) noexcept = delete;
+
+ ~ManagedConstTensorHandle()
+ {
+ // Bias tensor handles need to be initialized empty before entering scope of if statement checking if enabled
+ if (m_TensorHandle)
+ {
+ Unmap();
+ }
+ }
+
+ void Unmap()
+ {
+ // Only unmap if mapped and TensorHandle exists.
+ if (m_Mapped && m_TensorHandle)
+ {
+ m_TensorHandle->Unmap();
+ m_Mapped = false;
+ }
+ }
+
+ const TensorInfo& GetTensorInfo() const
+ {
+ return m_TensorHandle->GetTensorInfo();
+ }
+
+ bool IsMapped() const
+ {
+ return m_Mapped;
+ }
+
+private:
+ bool m_Mapped;
+ std::shared_ptr<ConstCpuTensorHandle> m_TensorHandle;
+};
+
} // namespace armnn
diff --git a/src/backends/backendsCommon/test/DefaultAsyncExecuteTest.cpp b/src/backends/backendsCommon/test/DefaultAsyncExecuteTest.cpp
index 0d4595210e..56a794e77c 100644
--- a/src/backends/backendsCommon/test/DefaultAsyncExecuteTest.cpp
+++ b/src/backends/backendsCommon/test/DefaultAsyncExecuteTest.cpp
@@ -243,7 +243,6 @@ BOOST_AUTO_TEST_CASE(TestDefaultAsyncExeuteWithThreads)
ValidateTensor(workingMemDescriptor2.m_Inputs[0], expectedExecuteval2);
}
-
BOOST_AUTO_TEST_SUITE_END()
} \ No newline at end of file
diff --git a/src/backends/reference/test/RefTensorHandleTests.cpp b/src/backends/reference/test/RefTensorHandleTests.cpp
index 1ef6de9b32..b04d9d6c52 100644
--- a/src/backends/reference/test/RefTensorHandleTests.cpp
+++ b/src/backends/reference/test/RefTensorHandleTests.cpp
@@ -167,6 +167,39 @@ BOOST_AUTO_TEST_CASE(RefTensorHandleSupportsInPlaceComputation)
ARMNN_ASSERT(!(handleFactory.SupportsInPlaceComputation()));
}
+BOOST_AUTO_TEST_CASE(TestManagedConstTensorHandle)
+{
+ // Initialize arguments
+ void* mem = nullptr;
+ TensorInfo info;
+
+ // Use PassthroughCpuTensor as others are abstract
+ auto passThroughHandle = std::make_shared<PassthroughCpuTensorHandle>(info, mem);
+
+ // Test managed handle is initialized with m_Mapped unset and once Map() called its set
+ ManagedConstTensorHandle managedHandle(passThroughHandle);
+ BOOST_CHECK(!managedHandle.IsMapped());
+ managedHandle.Map();
+ BOOST_CHECK(managedHandle.IsMapped());
+
+ // Test it can then be unmapped
+ managedHandle.Unmap();
+ BOOST_CHECK(!managedHandle.IsMapped());
+
+ // Test member function
+ BOOST_CHECK(managedHandle.GetTensorInfo() == info);
+
+ // Test that nullptr tensor handle doesn't get mapped
+ ManagedConstTensorHandle managedHandleNull(nullptr);
+ BOOST_CHECK(!managedHandleNull.IsMapped());
+ BOOST_CHECK_THROW(managedHandleNull.Map(), armnn::Exception);
+ BOOST_CHECK(!managedHandleNull.IsMapped());
+
+ // Check Unmap() when m_Mapped already false
+ managedHandleNull.Unmap();
+ BOOST_CHECK(!managedHandleNull.IsMapped());
+}
+
#if !defined(__ANDROID__)
// Only run these tests on non Android platforms
BOOST_AUTO_TEST_CASE(CheckSourceType)