aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/Transpose.cpp
blob: a8e4a1cb81a0458822ed66bdb6ab8c433ac57d5b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <armnn/Tensor.hpp>

#include <armnnUtils/Transpose.hpp>

#include "Half.hpp"

#include <array>
#include <cstring>
#include <sstream>

namespace
{

class TransposeLoop
{
public:
    using size_type = unsigned int;

    TransposeLoop(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
        : m_SrcShape(srcShape)
    {
        if (srcShape.GetNumDimensions() != mappings.GetSize())
        {
            std::stringstream msg;
            msg << "Transpose: Number of shape dimensions (" << srcShape.GetNumDimensions() <<
                ") does not match the size of the mappings (" << mappings.GetSize() << ")";
            throw armnn::InvalidArgumentException(msg.str());
        }

        const size_type numDims = srcShape.GetNumDimensions();

        size_type srcStride = 1U;
        size_type dstStride = 1U;

        for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
        {
            m_SrcStrides[i] = srcStride;
            m_DstStrides[mappings[i]] = dstStride;

            srcStride *= srcShape[i];
            dstStride *= srcShape[mappings[i]];
        }
    }

    void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
    {
        if (srcData == nullptr)
        {
            throw armnn::Exception("Transpose: Source Data pointer is null");
        }
        if (dstData == nullptr)
        {
            throw armnn::Exception("Transpose: Destination Data pointer is null");
        }
        if (dataTypeSize == 0)
        {
            throw armnn::Exception("Transpose: dataTypeSize is zero");
        }

        const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
        unsigned char* dstDataPtr       = reinterpret_cast<unsigned char*>(dstData);

        const unsigned char* const srcEndPtr = srcDataPtr + m_SrcShape.GetNumElements() * dataTypeSize;
        unsigned char* const       dstEndPtr = dstDataPtr + m_SrcShape.GetNumElements() * dataTypeSize;

        Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
    }

private:
    void Unroll(size_type dimension,
                const unsigned char* srcData, unsigned char* dstData,
                const unsigned char* srcEnd, unsigned char* dstEnd,
                size_t dataTypeSize)
    {
        if (srcData == nullptr)
        {
            throw armnn::Exception("Transpose: Source Data pointer is null");
        }
        if (dstData == nullptr)
        {
            throw armnn::Exception("Transpose: Destination Data pointer is null");
        }
        if (srcEnd == nullptr)
        {
            throw armnn::Exception("Transpose: Source End pointer is null");
        }
        if (dstEnd == nullptr)
        {
            throw armnn::Exception("Transpose: Destination End is zero");
        }
        if (dataTypeSize == 0)
        {
            throw armnn::Exception("Transpose: dataTypeSize is invalid");
        }

        if (dimension >= m_SrcShape.GetNumDimensions())
        {
            ::memcpy(dstData, srcData, dataTypeSize);
        }
        else
        {
            for (size_type i = 0; i < m_SrcShape[dimension]; i++)
            {
                Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);

                srcData += m_SrcStrides[dimension] * dataTypeSize;
                dstData += m_DstStrides[dimension] * dataTypeSize;
            }
        }
    }

    armnn::TensorShape m_SrcShape;
    std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
    std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
};

} // namespace

namespace armnnUtils
{

/// Computes the shape produced by transposing srcShape with mappings.
///
/// Output dimension i takes the size of source dimension mappings[i].
///
/// @param srcShape  Shape to transpose.
/// @param mappings  Permutation; its size must equal srcShape.GetNumDimensions().
/// @return The transposed shape.
/// @throws armnn::InvalidArgumentException if the sizes of srcShape and mappings differ.
armnn::TensorShape TransposeTensorShape(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
{
    const unsigned int rank = mappings.GetSize();
    if (srcShape.GetNumDimensions() != rank)
    {
        std::stringstream errorText;
        errorText << "Transpose: Number of shape dimensions (" << srcShape.GetNumDimensions()
                  << ") does not match the size of the mappings (" << rank << ")";
        throw armnn::InvalidArgumentException(errorText.str());
    }

    unsigned int transposedDims[armnn::MaxNumOfTensorDimensions];
    for (unsigned int dstDim = 0U; dstDim != rank; ++dstDim)
    {
        const unsigned int srcDim = mappings[dstDim];
        transposedDims[dstDim] = srcShape[srcDim];
    }

    return armnn::TensorShape(rank, transposedDims);
}

/// Returns a copy of info whose shape has been transposed with mappings.
/// Every other property of the TensorInfo is carried over unchanged by the copy.
/// @throws armnn::InvalidArgumentException if the shape rank and mappings size differ.
armnn::TensorInfo TransposeTensorShape(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings)
{
    const armnn::TensorShape transposedShape = TransposeTensorShape(info.GetShape(), mappings);

    armnn::TensorInfo transposedInfo = info;
    transposedInfo.SetShape(transposedShape);
    return transposedInfo;
}

/// Copies src into dst, rearranging elements according to the transpose permutation.
///
/// @param srcShape      Shape of the source tensor.
/// @param mappings      Permutation; its size must equal srcShape.GetNumDimensions().
/// @param src           Source buffer of srcShape.GetNumElements() elements.
/// @param dst           Destination buffer with room for the same number of elements.
/// @param dataTypeSize  Size in bytes of a single element.
/// @throws armnn::InvalidArgumentException if the shape rank and mappings size differ.
/// @throws armnn::Exception if src or dst is null, or dataTypeSize is zero.
void Transpose(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings,
               const void* src, void* dst, size_t dataTypeSize)
{
    TransposeLoop loop(srcShape, mappings);
    loop.Unroll(src, dst, dataTypeSize);
}

} // namespace armnnUtils