ArmNN
 22.11
MockTensorHandle.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnTestUtils/MockTensorHandle.hpp"
7 
8 namespace armnn
9 {
10 
11 MockTensorHandle::MockTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<MockMemoryManager>& memoryManager)
12  : m_TensorInfo(tensorInfo)
13  , m_MemoryManager(memoryManager)
14  , m_Pool(nullptr)
15  , m_UnmanagedMemory(nullptr)
16  , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
17  , m_Imported(false)
18  , m_IsImportEnabled(false)
19 {}
20 
21 MockTensorHandle::MockTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags)
22  : m_TensorInfo(tensorInfo)
23  , m_Pool(nullptr)
24  , m_UnmanagedMemory(nullptr)
25  , m_ImportFlags(importFlags)
26  , m_Imported(false)
27  , m_IsImportEnabled(true)
28 {}
29 
30 MockTensorHandle::~MockTensorHandle()
31 {
32  if (!m_Pool)
33  {
34  // unmanaged
35  if (!m_Imported)
36  {
37  ::operator delete(m_UnmanagedMemory);
38  }
39  }
40 }
41 
42 void MockTensorHandle::Manage()
43 {
44  if (!m_IsImportEnabled)
45  {
46  ARMNN_ASSERT_MSG(!m_Pool, "MockTensorHandle::Manage() called twice");
47  ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "MockTensorHandle::Manage() called after Allocate()");
48 
49  m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
50  }
51 }
52 
53 void MockTensorHandle::Allocate()
54 {
55  // If import is enabled, do not allocate the tensor
56  if (!m_IsImportEnabled)
57  {
58 
59  if (!m_UnmanagedMemory)
60  {
61  if (!m_Pool)
62  {
63  // unmanaged
64  m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
65  }
66  else
67  {
68  m_MemoryManager->Allocate(m_Pool);
69  }
70  }
71  else
72  {
73  throw InvalidArgumentException("MockTensorHandle::Allocate Trying to allocate a MockTensorHandle"
74  "that already has allocated memory.");
75  }
76  }
77 }
78 
79 const void* MockTensorHandle::Map(bool /*unused*/) const
80 {
81  return GetPointer();
82 }
83 
84 void* MockTensorHandle::GetPointer() const
85 {
86  if (m_UnmanagedMemory)
87  {
88  return m_UnmanagedMemory;
89  }
90  else if (m_Pool)
91  {
92  return m_MemoryManager->GetPointer(m_Pool);
93  }
94  else
95  {
96  throw NullPointerException("MockTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
97  }
98 }
99 
100 void MockTensorHandle::CopyOutTo(void* dest) const
101 {
102  const void* src = GetPointer();
103  ARMNN_ASSERT(src);
104  memcpy(dest, src, m_TensorInfo.GetNumBytes());
105 }
106 
107 void MockTensorHandle::CopyInFrom(const void* src)
108 {
109  void* dest = GetPointer();
110  ARMNN_ASSERT(dest);
111  memcpy(dest, src, m_TensorInfo.GetNumBytes());
112 }
113 
114 bool MockTensorHandle::Import(void* memory, MemorySource source)
115 {
116  if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
117  {
118  if (m_IsImportEnabled && source == MemorySource::Malloc)
119  {
120  // Check memory alignment
121  if (!CanBeImported(memory, source))
122  {
123  if (m_Imported)
124  {
125  m_Imported = false;
126  m_UnmanagedMemory = nullptr;
127  }
128 
129  return false;
130  }
131 
132  // m_UnmanagedMemory not yet allocated.
133  if (!m_Imported && !m_UnmanagedMemory)
134  {
135  m_UnmanagedMemory = memory;
136  m_Imported = true;
137  return true;
138  }
139 
140  // m_UnmanagedMemory initially allocated with Allocate().
141  if (!m_Imported && m_UnmanagedMemory)
142  {
143  return false;
144  }
145 
146  // m_UnmanagedMemory previously imported.
147  if (m_Imported)
148  {
149  m_UnmanagedMemory = memory;
150  return true;
151  }
152  }
153  }
154 
155  return false;
156 }
157 
158 bool MockTensorHandle::CanBeImported(void* memory, MemorySource source)
159 {
160  if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
161  {
162  if (m_IsImportEnabled && source == MemorySource::Malloc)
163  {
164  uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
165  if (reinterpret_cast<uintptr_t>(memory) % alignment)
166  {
167  return false;
168  }
169 
170  return true;
171  }
172  }
173  return false;
174 }
175 
176 } // namespace armnn
unsigned int GetNumBytes() const
Definition: Tensor.cpp:427
unsigned int MemorySourceFlags
Copyright (c) 2021 ARM Limited and Contributors.
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
DataType GetDataType() const
Definition: Tensor.hpp:198
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
MemorySource
Define the Memory Source to reduce copies.
Definition: Types.hpp:230
constexpr unsigned int GetDataTypeSize(DataType dataType)
Definition: TypesUtils.hpp:151