aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/backends/TensorHandle.hpp
blob: 2e6c8485d186a25dcd2753e265ecf3a9b0e92d90 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ITensorHandle.hpp"

#include <armnn/TypesUtils.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnnUtils/CompatibleTypes.hpp>

#include <algorithm>

namespace armnn
{

// Get a TensorShape representing the strides (in bytes) for each dimension
// of a tensor, assuming fully packed data with no padding.
// Only declared in this header. ConstTensorHandle::GetStrides() forwards to it.
TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);

// Abstract tensor handles wrapping a readable region of memory, interpreting it as tensor data.
class ConstTensorHandle : public ITensorHandle
{
public:
    template <typename T>
    const T* GetConstTensor() const
    {
        if (armnnUtils::CompatibleTypes<T>(GetTensorInfo().GetDataType()))
        {
            return reinterpret_cast<const T*>(m_Memory);
        }
        else
        {
            throw armnn::Exception("Attempting to get not compatible type tensor!");
        }
    }

    const TensorInfo& GetTensorInfo() const
    {
        return m_TensorInfo;
    }

    virtual void Manage() override {}

    virtual ITensorHandle* GetParent() const override { return nullptr; }

    virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
    virtual void Unmap() const override {}

    TensorShape GetStrides() const override
    {
        return GetUnpaddedTensorStrides(m_TensorInfo);
    }
    TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }

protected:
    ConstTensorHandle(const TensorInfo& tensorInfo);

    void SetConstMemory(const void* mem) { m_Memory = mem; }

private:
    // Only used for testing
    void CopyOutTo(void *) const override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
    void CopyInFrom(const void*) override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }

    ConstTensorHandle(const ConstTensorHandle& other) = delete;
    ConstTensorHandle& operator=(const ConstTensorHandle& other) = delete;

    TensorInfo m_TensorInfo;
    const void* m_Memory;
};

// Explicit specialization of GetConstTensor for T = void.
// Only declared in this header; no definition appears here.
template<>
const void* ConstTensorHandle::GetConstTensor<void>() const;

// Writable counterpart of ConstTensorHandle: additionally keeps a non-const pointer to
// the same memory that the base class exposes as const.
// The constructor is protected, so only derived classes can create one, and the copy
// operations are deleted.
class TensorHandle : public ConstTensorHandle
{
public:
    // Returns the wrapped memory viewed as a pointer to (mutable) T.
    // Throws armnn::Exception when T is not compatible with the data type recorded
    // in the handle's TensorInfo.
    template <typename T>
    T* GetTensor() const
    {
        const bool typeMatches = armnnUtils::CompatibleTypes<T>(GetTensorInfo().GetDataType());
        if (!typeMatches)
        {
            throw armnn::Exception("Attempting to get not compatible type tensor!");
        }
        return static_cast<T*>(m_MutableMemory);
    }

protected:
    TensorHandle(const TensorInfo& tensorInfo);

    // Stores the writable pointer and mirrors it into the base class so that the
    // const and mutable views always refer to the same memory.
    void SetMemory(void* mem)
    {
        m_MutableMemory = mem;
        SetConstMemory(mem);
    }

private:
    // Not copyable.
    TensorHandle(const TensorHandle&) = delete;
    TensorHandle& operator=(const TensorHandle&) = delete;

    void* m_MutableMemory;
};

// Explicit specialization of GetTensor for T = void.
// Only declared in this header; no definition appears here.
template <>
void* TensorHandle::GetTensor<void>() const;

// A TensorHandle that owns the wrapped memory region.
//
// Only declarations appear in this header; no member is defined here.
// Unlike its base classes, which delete their copy operations, this class declares
// its own copy constructor and copy assignment operator.
class ScopedTensorHandle : public TensorHandle
{
public:
    // Creates a handle described by tensorInfo.
    // NOTE(review): presumably no memory is allocated until Allocate() is called - confirm in the definition.
    explicit ScopedTensorHandle(const TensorInfo& tensorInfo);

    // Copies contents from Tensor.
    explicit ScopedTensorHandle(const ConstTensor& tensor);

    // Copies contents from ConstTensorHandle
    explicit ScopedTensorHandle(const ConstTensorHandle& tensorHandle);

    // User-declared copy operations and destructor.
    // NOTE(review): presumably these deep-copy / release the owned region - confirm in the definitions.
    ScopedTensorHandle(const ScopedTensorHandle& other);
    ScopedTensorHandle& operator=(const ScopedTensorHandle& other);
    ~ScopedTensorHandle();

    // Overrides the Allocate() virtual inherited through TensorHandle.
    virtual void Allocate() override;

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override;
    void CopyInFrom(const void* memory) override;

    // Private copy helpers.
    // NOTE(review): presumably shared by the copying constructors and operator= - confirm.
    void CopyFrom(const ScopedTensorHandle& other);
    void CopyFrom(const void* srcMemory, unsigned int numBytes);
};

// A TensorHandle that wraps an already allocated memory region.
//
// Clients must make sure the passed in memory region stays alive for the lifetime of
// the PassthroughTensorHandle instance.
//
// Note there is no polymorphism to/from ConstPassthroughTensorHandle.
class PassthroughTensorHandle : public TensorHandle
{
public:
    PassthroughTensorHandle(const TensorInfo& tensorInfo, void* mem)
    :   TensorHandle(tensorInfo)
    {
        SetMemory(mem);
    }

    virtual void Allocate() override;
};

// A ConstTensorHandle that wraps an already allocated memory region.
//
// This allows users to pass in const memory to a network.
// Clients must make sure the passed in memory region stays alive for the lifetime of
// the PassthroughTensorHandle instance.
//
// Note there is no polymorphism to/from PassthroughTensorHandle.
class ConstPassthroughTensorHandle : public ConstTensorHandle
{
public:
    ConstPassthroughTensorHandle(const TensorInfo& tensorInfo, const void* mem)
    :   ConstTensorHandle(tensorInfo)
    {
        SetConstMemory(mem);
    }

    virtual void Allocate() override;
};


// Template specializations.
// These redeclare the T = void explicit specializations that are already declared
// directly after ConstTensorHandle and TensorHandle earlier in this header. Here the
// template argument is omitted and deduced from the return type (const void* / void*).

template <>
const void* ConstTensorHandle::GetConstTensor() const;

template <>
void* TensorHandle::GetTensor() const;

/// Wraps a shared ConstTensorHandle and tracks whether it has been mapped through this
/// wrapper, so that the handle is unmapped again when the wrapper goes out of scope.
///
/// The wrapped pointer may be null (see the destructor). In that case Map() and
/// GetTensorInfo() throw armnn::Exception, while Unmap() and the destructor do nothing.
/// Instances can be neither copied nor moved.
class ManagedConstTensorHandle
{

public:
    /// Takes shared ownership of @p ptr (which may be null). The wrapper starts out unmapped.
    explicit ManagedConstTensorHandle(std::shared_ptr<ConstTensorHandle> ptr)
        : m_Mapped(false)
        , m_TensorHandle(std::move(ptr)) {}

    /// RAII Managed resource Unmaps MemoryArea once out of scope
    ///
    /// @param blocking Forwarded unchanged to the wrapped handle's Map().
    /// @return The pointer returned by the wrapped handle's Map().
    /// @throws armnn::Exception if no handle is wrapped.
    const void* Map(bool blocking = true)
    {
        if (!m_TensorHandle)
        {
            throw armnn::Exception("Attempting to Map null TensorHandle");
        }

        const void* mappedMemory = m_TensorHandle->Map(blocking);
        // Flag is set only after Map() has returned, so a throwing Map() leaves
        // the wrapper in the unmapped state.
        m_Mapped = true;
        return mappedMemory;
    }

    // Delete copy constructor as it's unnecessary.
    // The parameter must be the wrapper type itself for this to be the copy constructor.
    ManagedConstTensorHandle(const ManagedConstTensorHandle& other) = delete;

    // Delete copy assignment as it's unnecessary
    ManagedConstTensorHandle& operator=(const ManagedConstTensorHandle& other) = delete;

    // Delete move assignment as it's unnecessary
    ManagedConstTensorHandle& operator=(ManagedConstTensorHandle&& other) noexcept = delete;

    /// Unmaps the wrapped handle if it is still mapped through this wrapper.
    ~ManagedConstTensorHandle()
    {
        // Bias tensor handles need to be initialized empty before entering scope of if statement checking if enabled,
        // so a null handle is expected here; Unmap() already does nothing in that case.
        Unmap();
    }

    /// Unmaps the wrapped handle and clears the mapped flag.
    /// Does nothing if the wrapper is not currently mapped or wraps no handle.
    void Unmap()
    {
        // Only unmap if mapped and TensorHandle exists.
        if (m_Mapped && m_TensorHandle)
        {
            m_TensorHandle->Unmap();
            m_Mapped = false;
        }
    }

    /// @return The TensorInfo of the wrapped handle.
    /// @throws armnn::Exception if no handle is wrapped (previously a null dereference).
    const TensorInfo& GetTensorInfo() const
    {
        if (!m_TensorHandle)
        {
            throw armnn::Exception("Attempting to get TensorInfo of null TensorHandle");
        }
        return m_TensorHandle->GetTensorInfo();
    }

    /// @return true if Map() has succeeded through this wrapper and Unmap() has not run since.
    bool IsMapped() const
    {
        return m_Mapped;
    }

private:
    bool m_Mapped;                                      ///< Whether this wrapper currently holds a mapping.
    std::shared_ptr<ConstTensorHandle> m_TensorHandle;  ///< Wrapped handle; may be null.
};

// Backwards-compatibility aliases mapping the former "Cpu"-prefixed names onto the
// current classes. Each alias is annotated through ARMNN_DEPRECATED_MSG_REMOVAL_DATE
// with a message naming its replacement and the string "22.05".
// NOTE(review): "22.05" is presumably the release in which the alias is removed -
// confirm against the macro's definition.
using ConstCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ConstCpuTensorHandle is deprecated, "
                                                "use ConstTensorHandle instead", "22.05") = ConstTensorHandle;
using CpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("CpuTensorHandle is deprecated, "
                                           "use TensorHandle instead", "22.05") = TensorHandle;
using ScopedCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ScopedCpuTensorHandle is deprecated, "
                                                 "use ScopedTensorHandle instead", "22.05") = ScopedTensorHandle;
using PassthroughCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("PassthroughCpuTensorHandle is deprecated, use "
                                                      "PassthroughTensorHandle instead",
                                                      "22.05") = PassthroughTensorHandle;
using ConstPassthroughCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ConstPassthroughCpuTensorHandle is "
                                                           "deprecated, use ConstPassthroughTensorHandle "
                                                           "instead", "22.05") = ConstPassthroughTensorHandle;

} // namespace armnn