ArmNN
 20.02
CpuTensorHandle.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <armnn/Exceptions.hpp>
7 
9 
10 #include <cstring>
11 
12 namespace armnn
13 {
14 
16 {
17  TensorShape shape(tensorInfo.GetShape());
18  auto size = GetDataTypeSize(tensorInfo.GetDataType());
19  auto runningSize = size;
20  std::vector<unsigned int> strides(shape.GetNumDimensions());
21  auto lastIdx = shape.GetNumDimensions()-1;
22  for (unsigned int i=0; i < lastIdx ; i++)
23  {
24  strides[lastIdx-i] = runningSize;
25  runningSize *= shape[lastIdx-i];
26  }
27  strides[0] = runningSize;
28  return TensorShape(shape.GetNumDimensions(), strides.data());
29 }
30 
32 : m_TensorInfo(tensorInfo)
33 , m_Memory(nullptr)
34 {
35 }
36 
37 template <>
38 const void* ConstCpuTensorHandle::GetConstTensor<void>() const
39 {
40  return m_Memory;
41 }
42 
44 : ConstCpuTensorHandle(tensorInfo)
45 , m_MutableMemory(nullptr)
46 {
47 }
48 
49 template <>
50 void* CpuTensorHandle::GetTensor<void>() const
51 {
52  return m_MutableMemory;
53 }
54 
56 : CpuTensorHandle(tensorInfo)
57 {
58 }
59 
61 : ScopedCpuTensorHandle(tensor.GetInfo())
62 {
63  CopyFrom(tensor.GetMemoryArea(), tensor.GetNumBytes());
64 }
65 
67 : ScopedCpuTensorHandle(tensorHandle.GetTensorInfo())
68 {
69  CopyFrom(tensorHandle.GetConstTensor<void>(), tensorHandle.GetTensorInfo().GetNumBytes());
70 }
71 
74 {
75  CopyFrom(other);
76 }
77 
79 {
80  ::operator delete(GetTensor<void>());
81  SetMemory(nullptr);
82  CopyFrom(other);
83  return *this;
84 }
85 
87 {
88  ::operator delete(GetTensor<void>());
89 }
90 
92 {
93  if (GetTensor<void>() == nullptr)
94  {
95  SetMemory(::operator new(GetTensorInfo().GetNumBytes()));
96  }
97  else
98  {
99  throw InvalidArgumentException("CpuTensorHandle::Allocate Trying to allocate a CpuTensorHandle"
100  "that already has allocated memory.");
101  }
102 }
103 
104 void ScopedCpuTensorHandle::CopyOutTo(void* memory) const
105 {
106  memcpy(memory, GetTensor<void>(), GetTensorInfo().GetNumBytes());
107 }
108 
109 void ScopedCpuTensorHandle::CopyInFrom(const void* memory)
110 {
111  memcpy(GetTensor<void>(), memory, GetTensorInfo().GetNumBytes());
112 }
113 
114 void ScopedCpuTensorHandle::CopyFrom(const ScopedCpuTensorHandle& other)
115 {
116  CopyFrom(other.GetTensor<void>(), other.GetTensorInfo().GetNumBytes());
117 }
118 
119 void ScopedCpuTensorHandle::CopyFrom(const void* srcMemory, unsigned int numBytes)
120 {
121  BOOST_ASSERT(GetTensor<void>() == nullptr);
122  BOOST_ASSERT(GetTensorInfo().GetNumBytes() == numBytes);
123 
124  if (srcMemory)
125  {
126  Allocate();
127  memcpy(GetTensor<void>(), srcMemory, numBytes);
128  }
129 }
130 
132 {
133  throw InvalidArgumentException("PassthroughCpuTensorHandle::Allocate() should never be called");
134 }
135 
137 {
138  throw InvalidArgumentException("ConstPassthroughCpuTensorHandle::Allocate() should never be called");
139 }
140 
141 } // namespace armnn
const TensorShape & GetShape() const
Definition: Tensor.hpp:88
unsigned int GetNumBytes() const
Definition: Tensor.cpp:213
void SetMemory(void *mem)
MemoryType GetMemoryArea() const
Definition: Tensor.hpp:177
const T * GetConstTensor() const
Copyright (c) 2020 ARM Limited.
ScopedCpuTensorHandle & operator=(const ScopedCpuTensorHandle &other)
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
DataType GetDataType() const
Definition: Tensor.hpp:95
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:199
TensorShape GetUnpaddedTensorStrides(const TensorInfo &tensorInfo)
ConstCpuTensorHandle(const TensorInfo &tensorInfo)
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
ScopedCpuTensorHandle(const TensorInfo &tensorInfo)
CpuTensorHandle(const TensorInfo &tensorInfo)
const TensorInfo & GetTensorInfo() const
unsigned int GetNumBytes() const
Definition: Tensor.hpp:174
constexpr unsigned int GetDataTypeSize(DataType dataType)
Definition: TypesUtils.hpp:115
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.