ArmNN  NotReleased
CpuTensorHandle.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
10 
11 #include <armnn/TypesUtils.hpp>
12 
13 #include <CompatibleTypes.hpp>
14 
15 #include <algorithm>
16 
17 #include <boost/assert.hpp>
18 
19 namespace armnn
20 {
21 
22 // Get a TensorShape representing the strides (in bytes) for each dimension
23 // of a tensor, assuming fully packed data with no padding
24 TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);
25 
26 // Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
28 {
29 public:
30  template <typename T>
31  const T* GetConstTensor() const
32  {
33  BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
34  return reinterpret_cast<const T*>(m_Memory);
35  }
36 
37  const TensorInfo& GetTensorInfo() const
38  {
39  return m_TensorInfo;
40  }
41 
42  virtual void Manage() override {}
43 
44  virtual ITensorHandle* GetParent() const override { return nullptr; }
45 
46  virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
47  virtual void Unmap() const override {}
48 
49  TensorShape GetStrides() const override
50  {
51  return GetUnpaddedTensorStrides(m_TensorInfo);
52  }
53  TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
54 
55 protected:
56  ConstCpuTensorHandle(const TensorInfo& tensorInfo);
57 
58  void SetConstMemory(const void* mem) { m_Memory = mem; }
59 
60 private:
61  // Only used for testing
62  void CopyOutTo(void *) const override { BOOST_ASSERT_MSG(false, "Unimplemented"); }
63  void CopyInFrom(const void*) override { BOOST_ASSERT_MSG(false, "Unimplemented"); }
64 
65  ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
66  ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;
67 
68  TensorInfo m_TensorInfo;
69  const void* m_Memory;
70 };
71 
72 template<>
73 const void* ConstCpuTensorHandle::GetConstTensor<void>() const;
74 
75 // Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
77 {
78 public:
79  template <typename T>
80  T* GetTensor() const
81  {
82  BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
83  return reinterpret_cast<T*>(m_MutableMemory);
84  }
85 
86 protected:
87  CpuTensorHandle(const TensorInfo& tensorInfo);
88 
89  void SetMemory(void* mem)
90  {
91  m_MutableMemory = mem;
92  SetConstMemory(m_MutableMemory);
93  }
94 
95 private:
96 
97  CpuTensorHandle(const CpuTensorHandle& other) = delete;
98  CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete;
99  void* m_MutableMemory;
100 };
101 
102 template <>
103 void* CpuTensorHandle::GetTensor<void>() const;
104 
105 // A CpuTensorHandle that owns the wrapped memory region.
107 {
108 public:
109  explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
110 
111  // Copies contents from Tensor.
112  explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
113 
114  // Copies contents from ConstCpuTensorHandle
115  explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
116 
118  ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
120 
121  virtual void Allocate() override;
122 
123 private:
124  // Only used for testing
125  void CopyOutTo(void* memory) const override;
126  void CopyInFrom(const void* memory) override;
127 
128  void CopyFrom(const ScopedCpuTensorHandle& other);
129  void CopyFrom(const void* srcMemory, unsigned int numBytes);
130 };
131 
132 // A CpuTensorHandle that wraps an already allocated memory region.
133 //
134 // Clients must make sure the passed in memory region stays alive for the lifetime of
135 // the PassthroughCpuTensorHandle instance.
136 //
137 // Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
139 {
140 public:
141  PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem)
142  : CpuTensorHandle(tensorInfo)
143  {
144  SetMemory(mem);
145  }
146 
147  virtual void Allocate() override;
148 };
149 
150 // A ConstCpuTensorHandle that wraps an already allocated memory region.
151 //
152 // This allows users to pass in const memory to a network.
153 // Clients must make sure the passed in memory region stays alive for the lifetime of
154 // the PassthroughCpuTensorHandle instance.
155 //
156 // Note there is no polymorphism to/from PassthroughCpuTensorHandle.
158 {
159 public:
160  ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem)
161  : ConstCpuTensorHandle(tensorInfo)
162  {
163  SetConstMemory(mem);
164  }
165 
166  virtual void Allocate() override;
167 };
168 
169 
// Template specializations.
//
// The void specializations ConstCpuTensorHandle::GetConstTensor<void>() and
// CpuTensorHandle::GetTensor<void>() are declared immediately after their respective
// class definitions; repeating those declarations here would be redundant.
177 
178 } // namespace armnn
const T * GetConstTensor() const
const TensorInfo & GetTensorInfo() const
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:199
ConstCpuTensorHandle(const TensorInfo &tensorInfo)
virtual void Manage() override
PassthroughCpuTensorHandle(const TensorInfo &tensorInfo, void *mem)
ConstPassthroughCpuTensorHandle(const TensorInfo &tensorInfo, const void *mem)
TensorShape GetStrides() const override
void SetMemory(void *mem)
TensorShape GetShape() const override
virtual ITensorHandle * GetParent() const override
TensorShape GetUnpaddedTensorStrides(const TensorInfo &tensorInfo)
void SetConstMemory(const void *mem)
virtual const void * Map(bool) const override
virtual void Allocate()=0
virtual void Unmap() const override
Unmap the tensor data.
const TensorShape & GetShape() const
Definition: Tensor.hpp:88