ArmNN
 22.08
TensorHandle.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include "ITensorHandle.hpp"

#include <armnn/TypesUtils.hpp>
#include <armnn/utility/Assert.hpp>

#include <algorithm>
#include <memory>
#include <utility>
15 
16 namespace armnn
17 {
18 
19 // Get a TensorShape representing the strides (in bytes) for each dimension
20 // of a tensor, assuming fully packed data with no padding
21 TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);
22 
23 // Abstract tensor handles wrapping a readable region of memory, interpreting it as tensor data.
25 {
26 public:
27  template <typename T>
28  const T* GetConstTensor() const
29  {
30  if (armnnUtils::CompatibleTypes<T>(GetTensorInfo().GetDataType()))
31  {
32  return reinterpret_cast<const T*>(m_Memory);
33  }
34  else
35  {
36  throw armnn::Exception("Attempting to get not compatible type tensor!");
37  }
38  }
39 
40  const TensorInfo& GetTensorInfo() const
41  {
42  return m_TensorInfo;
43  }
44 
45  virtual void Manage() override {}
46 
47  virtual ITensorHandle* GetParent() const override { return nullptr; }
48 
49  virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
50  virtual void Unmap() const override {}
51 
52  TensorShape GetStrides() const override
53  {
54  return GetUnpaddedTensorStrides(m_TensorInfo);
55  }
56  TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
57 
58 protected:
59  ConstTensorHandle(const TensorInfo& tensorInfo);
60 
61  void SetConstMemory(const void* mem) { m_Memory = mem; }
62 
63 private:
64  // Only used for testing
65  void CopyOutTo(void *) const override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
66  void CopyInFrom(const void*) override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
67 
68  ConstTensorHandle(const ConstTensorHandle& other) = delete;
69  ConstTensorHandle& operator=(const ConstTensorHandle& other) = delete;
70 
71  TensorInfo m_TensorInfo;
72  const void* m_Memory;
73 };
74 
75 template<>
76 const void* ConstTensorHandle::GetConstTensor<void>() const;
77 
78 // Abstract specialization of ConstTensorHandle that allows write access to the same data.
80 {
81 public:
82  template <typename T>
83  T* GetTensor() const
84  {
85  if (armnnUtils::CompatibleTypes<T>(GetTensorInfo().GetDataType()))
86  {
87  return reinterpret_cast<T*>(m_MutableMemory);
88  }
89  else
90  {
91  throw armnn::Exception("Attempting to get not compatible type tensor!");
92  }
93  }
94 
95 protected:
96  TensorHandle(const TensorInfo& tensorInfo);
97 
98  void SetMemory(void* mem)
99  {
100  m_MutableMemory = mem;
101  SetConstMemory(m_MutableMemory);
102  }
103 
104 private:
105 
106  TensorHandle(const TensorHandle& other) = delete;
107  TensorHandle& operator=(const TensorHandle& other) = delete;
108  void* m_MutableMemory;
109 };
110 
111 template <>
112 void* TensorHandle::GetTensor<void>() const;
113 
114 // A TensorHandle that owns the wrapped memory region.
116 {
117 public:
118  explicit ScopedTensorHandle(const TensorInfo& tensorInfo);
119 
120  // Copies contents from Tensor.
121  explicit ScopedTensorHandle(const ConstTensor& tensor);
122 
123  // Copies contents from ConstTensorHandle
124  explicit ScopedTensorHandle(const ConstTensorHandle& tensorHandle);
125 
127  ScopedTensorHandle& operator=(const ScopedTensorHandle& other);
129 
130  virtual void Allocate() override;
131 
132 private:
133  // Only used for testing
134  void CopyOutTo(void* memory) const override;
135  void CopyInFrom(const void* memory) override;
136 
137  void CopyFrom(const ScopedTensorHandle& other);
138  void CopyFrom(const void* srcMemory, unsigned int numBytes);
139 };
140 
141 // A TensorHandle that wraps an already allocated memory region.
142 //
143 // Clients must make sure the passed in memory region stays alive for the lifetime of
144 // the PassthroughTensorHandle instance.
145 //
146 // Note there is no polymorphism to/from ConstPassthroughTensorHandle.
148 {
149 public:
150  PassthroughTensorHandle(const TensorInfo& tensorInfo, void* mem)
151  : TensorHandle(tensorInfo)
152  {
153  SetMemory(mem);
154  }
155 
156  virtual void Allocate() override;
157 };
158 
159 // A ConstTensorHandle that wraps an already allocated memory region.
160 //
161 // This allows users to pass in const memory to a network.
162 // Clients must make sure the passed in memory region stays alive for the lifetime of
163 // the PassthroughTensorHandle instance.
164 //
165 // Note there is no polymorphism to/from PassthroughTensorHandle.
167 {
168 public:
169  ConstPassthroughTensorHandle(const TensorInfo& tensorInfo, const void* mem)
170  : ConstTensorHandle(tensorInfo)
171  {
172  SetConstMemory(mem);
173  }
174 
175  virtual void Allocate() override;
176 };
177 
178 
// Template specializations.
//
// The explicit specializations ConstTensorHandle::GetConstTensor<void>() and
// TensorHandle::GetTensor<void>() are declared immediately after their respective
// class definitions, before any use, so they are not re-declared here.
186 
188 {
189 
190 public:
191  explicit ManagedConstTensorHandle(std::shared_ptr<ConstTensorHandle> ptr)
192  : m_Mapped(false)
193  , m_TensorHandle(std::move(ptr)) {};
194 
195  /// RAII Managed resource Unmaps MemoryArea once out of scope
196  const void* Map(bool blocking = true)
197  {
198  if (m_TensorHandle)
199  {
200  auto pRet = m_TensorHandle->Map(blocking);
201  m_Mapped = true;
202  return pRet;
203  }
204  else
205  {
206  throw armnn::Exception("Attempting to Map null TensorHandle");
207  }
208 
209  }
210 
211  // Delete copy constructor as it's unnecessary
212  ManagedConstTensorHandle(const ConstTensorHandle& other) = delete;
213 
214  // Delete copy assignment as it's unnecessary
215  ManagedConstTensorHandle& operator=(const ManagedConstTensorHandle& other) = delete;
216 
217  // Delete move assignment as it's unnecessary
218  ManagedConstTensorHandle& operator=(ManagedConstTensorHandle&& other) noexcept = delete;
219 
221  {
222  // Bias tensor handles need to be initialized empty before entering scope of if statement checking if enabled
223  if (m_TensorHandle)
224  {
225  Unmap();
226  }
227  }
228 
229  void Unmap()
230  {
231  // Only unmap if mapped and TensorHandle exists.
232  if (m_Mapped && m_TensorHandle)
233  {
234  m_TensorHandle->Unmap();
235  m_Mapped = false;
236  }
237  }
238 
239  const TensorInfo& GetTensorInfo() const
240  {
241  return m_TensorHandle->GetTensorInfo();
242  }
243 
244  bool IsMapped() const
245  {
246  return m_Mapped;
247  }
248 
249 private:
250  bool m_Mapped;
251  std::shared_ptr<ConstTensorHandle> m_TensorHandle;
252 };
253 
254 } // namespace armnn
virtual const void * Map(bool) const override
Map the tensor data for access.
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
ConstPassthroughTensorHandle(const TensorInfo &tensorInfo, const void *mem)
void SetMemory(void *mem)
virtual void Allocate()=0
Indicate to the memory manager that this resource is no longer active.
void SetConstMemory(const void *mem)
TensorShape GetShape() const override
Get the number of elements for each dimension ordered from slowest iterating dimension to fastest iterating dimension.
const TensorInfo & GetTensorInfo() const
Copyright (c) 2021 ARM Limited and Contributors.
const TensorInfo & GetTensorInfo() const
ConstTensorHandle(const TensorInfo &tensorInfo)
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:327
ManagedConstTensorHandle(std::shared_ptr< ConstTensorHandle > ptr)
T * GetTensor() const
TensorShape GetUnpaddedTensorStrides(const TensorInfo &tensorInfo)
virtual void Manage() override
Indicate to the memory manager that this resource is active.
virtual void Unmap() const override
Unmap the tensor data.
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
TensorShape GetStrides() const override
Get the strides for each dimension ordered from largest to smallest where the smallest value is the same as the size of a single element in the tensor.
PassthroughTensorHandle(const TensorInfo &tensorInfo, void *mem)
virtual ITensorHandle * GetParent() const override
Get the parent tensor if this is a subtensor.
const void * Map(bool blocking=true)
RAII Managed resource Unmaps MemoryArea once out of scope.
const T * GetConstTensor() const