aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/Permute.cpp
blob: 6deff9016862ae7b6bab185934fb4cfe248ef90d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "Permute.hpp"

#include "Half.hpp"
#include <armnn/Tensor.hpp>

#include <array>
#include <cassert>
#include <cstring>

namespace
{

/// Copies the elements of a dense tensor into a permuted dense tensor.
///
/// Both buffers are walked as row-major arrays (the last dimension is contiguous). The walk
/// follows the destination shape. For every destination dimension two strides are precomputed:
/// - how far to move in the destination buffer;
/// - how far to move in the source buffer, i.e. the stride of the source dimension that maps
///   onto it.
/// The copy then recurses over the destination dimensions, advancing both pointers by those
/// strides.
class PermuteLoop
{
public:
    using size_type = unsigned int;

    /// @param dstShape  Shape of the destination (already permuted) tensor.
    /// @param mappings  Permutation: source dimension i becomes destination dimension mappings[i].
    PermuteLoop(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings)
        : m_DstShape(dstShape)
    {
        assert(dstShape.GetNumDimensions() == mappings.GetSize());

        const size_type numDims = dstShape.GetNumDimensions();

        size_type srcStride = 1U;
        size_type dstStride = 1U;

        // Walk from the innermost (last) dimension outwards, accumulating dense strides.
        // The stride of source dimension i is stored under index mappings[i], so both stride
        // tables can be indexed by destination dimension while unrolling.
        for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
        {
            m_SrcStrides[mappings[i]] = srcStride;
            m_DstStrides[i] = dstStride;

            // dstShape[mappings[i]] is the extent of source dimension i.
            srcStride *= dstShape[mappings[i]];
            dstStride *= dstShape[i];
        }
    }

    /// Copies every element from srcData into dstData in permuted order.
    /// Does nothing when the destination shape holds no elements. In that case the buffers
    /// are never dereferenced, so they may be empty.
    template <typename T>
    void Unroll(const T* srcData, T* dstData)
    {
        const auto numElements = m_DstShape.GetNumElements();
        if (numElements == 0)
        {
            // Nothing to copy. Without this guard the recursive walk would fail its
            // srcData < srcEnd assertions in debug builds. For a rank-0 shape it would
            // also copy one element regardless of the (zero) element count.
            return;
        }

        const T* const srcEnd = srcData + numElements;
        T* const       dstEnd = dstData + numElements;
        Unroll(0U, srcData, dstData, srcEnd, dstEnd);
    }

    /// Type-erased variant: copies every element of size dataTypeSize bytes from srcData into
    /// dstData in permuted order. Does nothing when the destination shape holds no elements.
    void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
    {
        assert(dataTypeSize > 0);

        const auto numElements = m_DstShape.GetNumElements();
        if (numElements == 0)
        {
            // See the typed overload: empty tensors must not touch either buffer.
            return;
        }

        assert(srcData);
        assert(dstData);

        const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
        unsigned char* dstDataPtr       = reinterpret_cast<unsigned char*>(dstData);

        const unsigned char* const srcEndPtr = srcDataPtr + numElements * dataTypeSize;
        unsigned char* const       dstEndPtr = dstDataPtr + numElements * dataTypeSize;

        Unroll(0U, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
    }

private:
    /// Recursive step for the typed copy.
    /// At the leaf (dimension == rank) one element is assigned. Otherwise the function recurses
    /// once per index of `dimension`, stepping both pointers by that dimension's strides.
    /// srcEnd/dstEnd are used only for the bounds assertions.
    template <typename T>
    void Unroll(size_type dimension, const T* srcData, T* dstData, const T* srcEnd, T* dstEnd)
    {
        assert(srcData);
        assert(dstData);
        assert(srcEnd);
        assert(dstEnd);
        assert(srcData < srcEnd);
        assert(dstData < dstEnd);

        if (dimension >= m_DstShape.GetNumDimensions())
        {
            *dstData = *srcData;
        }
        else
        {
            for (size_type i = 0; i < m_DstShape[dimension]; i++)
            {
                Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd);

                srcData += m_SrcStrides[dimension];
                dstData += m_DstStrides[dimension];
            }
        }
    }

    /// Recursive step for the byte-wise copy. Same traversal as the typed overload, but each
    /// leaf copies dataTypeSize bytes with memcpy, and the strides (counted in elements) are
    /// scaled to bytes.
    void Unroll(size_type dimension,
                const unsigned char* srcData, unsigned char* dstData,
                const unsigned char* srcEnd, unsigned char* dstEnd,
                size_t dataTypeSize)
    {
        assert(srcData);
        assert(dstData);
        assert(srcEnd);
        assert(dstEnd);
        assert(srcData < srcEnd);
        assert(dstData < dstEnd);
        assert(dataTypeSize > 0);

        if (dimension >= m_DstShape.GetNumDimensions())
        {
            ::memcpy(dstData, srcData, dataTypeSize);
        }
        else
        {
            for (size_type i = 0; i < m_DstShape[dimension]; i++)
            {
                Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);

                srcData += m_SrcStrides[dimension] * dataTypeSize;
                dstData += m_DstStrides[dimension] * dataTypeSize;
            }
        }
    }

    armnn::TensorShape m_DstShape;
    // Per-destination-dimension strides, in elements. Only the first rank entries are
    // meaningful; the rest are value-initialised to zero rather than left indeterminate.
    std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides{};
    std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides{};
};

} // namespace

namespace armnnUtils
{

/// Returns the shape obtained by moving each source dimension i to position mappings[i].
/// srcShape and mappings must have the same rank.
armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
{
    assert(srcShape.GetNumDimensions() == mappings.GetSize());

    const unsigned int rank = mappings.GetSize();
    unsigned int permutedDims[armnn::MaxNumOfTensorDimensions];

    // Scatter each source extent into its destination slot.
    unsigned int srcDim = 0U;
    while (srcDim < rank)
    {
        permutedDims[mappings[srcDim]] = srcShape[srcDim];
        ++srcDim;
    }

    return armnn::TensorShape(rank, permutedDims);
}

/// Returns a copy of info whose shape has been permuted by mappings.
/// All other tensor properties are copied unchanged.
armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings)
{
    const armnn::TensorShape permutedShape = Permuted(info.GetShape(), mappings);

    armnn::TensorInfo result(info);
    result.SetShape(permutedShape);
    return result;
}

/// Type-erased permute: copies the elements of src (each dataTypeSize bytes) into dst.
/// dst is laid out densely according to dstShape; source dimension i becomes destination
/// dimension mappings[i].
void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
             const void* src, void* dst, size_t dataTypeSize)
{
    PermuteLoop loop(dstShape, mappings);
    loop.Unroll(src, dst, dataTypeSize);
}

/// Typed permute: copies the elements of src into dst, which is laid out densely according to
/// dstShape. Source dimension i becomes destination dimension mappings[i].
template <typename T>
void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, const T* src, T* dst)
{
    PermuteLoop loop(dstShape, mappings);
    loop.Unroll(src, dst);
}

// Instantiates for types.
template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
                      const armnn::Half* src, armnn::Half* dst);
template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
                      const float* src, float* dst);
template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
                      const uint8_t* src, uint8_t* dst);
template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
                      const int32_t* src, int32_t* dst);
template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
                      const bool* src, bool* dst);

} // namespace armnnUtils