aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/CpuTensorHandle.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/CpuTensorHandle.hpp')
-rw-r--r--src/armnn/backends/CpuTensorHandle.hpp41
1 files changed, 35 insertions, 6 deletions
diff --git a/src/armnn/backends/CpuTensorHandle.hpp b/src/armnn/backends/CpuTensorHandle.hpp
index 4bf4439083..3376650ec3 100644
--- a/src/armnn/backends/CpuTensorHandle.hpp
+++ b/src/armnn/backends/CpuTensorHandle.hpp
@@ -9,10 +9,12 @@
#include "OutputHandler.hpp"
+#include <algorithm>
+
namespace armnn
{
-// Abstract tensor handle wrapping a CPU-readable region of memory, interpreting it as tensor data.
+// Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
class ConstCpuTensorHandle : public ITensorHandle
{
public:
@@ -33,6 +35,30 @@ public:
return ITensorHandle::Cpu;
}
+ virtual void Manage() override {}
+
+ virtual ITensorHandle* GetParent() const override { return nullptr; }
+
+ virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
+ virtual void Unmap() const override {}
+
+ TensorShape GetStrides() const override
+ {
+ TensorShape shape(m_TensorInfo.GetShape());
+ auto size = GetDataTypeSize(m_TensorInfo.GetDataType());
+ auto runningSize = size;
+ std::vector<unsigned int> strides(shape.GetNumDimensions());
+ auto lastIdx = shape.GetNumDimensions()-1;
+ for (unsigned int i=0; i < lastIdx ; i++)
+ {
+ strides[lastIdx-i] = runningSize;
+ runningSize *= shape[lastIdx-i];
+ }
+ strides[0] = runningSize;
+ return TensorShape(shape.GetNumDimensions(), strides.data());
+ }
+ TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
+
protected:
ConstCpuTensorHandle(const TensorInfo& tensorInfo);
@@ -46,7 +72,7 @@ private:
const void* m_Memory;
};
-// Abstract specialization of ConstCpuTensorHandle that allows write access to the same data
+// Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
class CpuTensorHandle : public ConstCpuTensorHandle
{
public:
@@ -79,9 +105,12 @@ class ScopedCpuTensorHandle : public CpuTensorHandle
public:
explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
- // Copies contents from Tensor
+ // Copies contents from Tensor.
explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
+ // Copies contents from ConstCpuTensorHandle
+ explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
+
ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
~ScopedCpuTensorHandle();
@@ -98,7 +127,7 @@ private:
// Clients must make sure the passed in memory region stays alive for the lifetime of
// the PassthroughCpuTensorHandle instance.
//
-// Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle
+// Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
class PassthroughCpuTensorHandle : public CpuTensorHandle
{
public:
@@ -117,7 +146,7 @@ public:
// Clients must make sure the passed in memory region stays alive for the lifetime of
// the PassthroughCpuTensorHandle instance.
//
-// Note there is no polymorphism to/from PassthroughCpuTensorHandle
+// Note there is no polymorphism to/from PassthroughCpuTensorHandle.
class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
{
public:
@@ -131,7 +160,7 @@ public:
};
-// template specializations
+// Template specializations.
template <>
const void* ConstCpuTensorHandle::GetConstTensor() const;