aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/ReverseV2Impl.cpp
blob: f6d5fd74d18b19c2d9760b231c6814745b32b30f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ReverseV2Impl.hpp"

#include <armnn/backends/WorkloadData.hpp>
#include <armnn/Logging.hpp>
#include <armnnUtils/Permute.hpp>

namespace armnn
{

// Split a flattened element index into one coordinate per dimension.
//
// idx             - flattened index of the element.
// inputRank       - number of coordinates to produce.
// elementNumInner - per-dimension divisor; entry d is the amount the flattened index
//                   changes for one step along dimension d. The first inputRank
//                   entries are read and each must be non-zero (they are divisors).
//
// Returns a vector of inputRank coordinates, ordered from dimension 0 upwards.
std::vector<unsigned int> ReverseGetMultIdx(const unsigned int idx,
                                            unsigned int inputRank,
                                            std::vector<unsigned int>& elementNumInner)
{
    std::vector<unsigned int> coords;
    coords.reserve(inputRank);

    // Peel off one coordinate at a time; whatever is left over belongs to the later dimensions.
    unsigned int remainder = idx;
    for (unsigned int dim = 0; dim < inputRank; ++dim)
    {
        const unsigned int stride = elementNumInner[dim];
        coords.push_back(remainder / stride);
        remainder = remainder % stride;
    }

    return coords;
}

// Collapse a per-dimension coordinate list back into a single flattened index.
//
// idxList         - coordinates, one per dimension; the first inputRank entries are read.
// inputRank       - number of dimensions to combine.
// elementNumInner - per-dimension multiplier; entry d is the weight applied to coordinate d.
//
// Returns the sum over d of idxList[d] * elementNumInner[d] (0 when inputRank is 0).
unsigned int ReverseGetFlatIdx(const std::vector<unsigned int>& idxList,
                               unsigned int inputRank,
                               std::vector<unsigned int>& elementNumInner)
{
    unsigned int flatIdx = 0;
    unsigned int dim = 0;

    while (dim < inputRank)
    {
        const unsigned int contribution = idxList[dim] * elementNumInner[dim];
        flatIdx += contribution;
        ++dim;
    }

    return flatIdx;
}

// Map the flattened index of an input element to the flattened index it occupies
// in the reversed tensor.
//
// idx             - flattened index of the element in the input.
// inputRank       - number of dimensions to process.
// axisFlag        - axisFlag[d] is true when dimension d is reversed.
// dimSize         - dimSize[d] is the extent of dimension d.
// elementNumInner - elementNumInner[d] is the flattened-index step for one move along
//                   dimension d; every entry used must be non-zero (it is a divisor).
//
// Returns the flattened output index.
//
// The coordinate split, the mirroring and the re-flattening are done in one pass over
// the dimensions. This function is invoked once per tensor element, so it deliberately
// avoids building temporary coordinate vectors on each call.
unsigned int ReverseRelocateIdx(unsigned int idx,
                                unsigned int inputRank,
                                std::vector<bool>& axisFlag,
                                std::vector<unsigned int>& dimSize,
                                std::vector<unsigned int>& elementNumInner)
{
    unsigned int remaining = idx;
    unsigned int outputIdx = 0;

    for (unsigned int iDim = 0; iDim < inputRank; ++iDim)
    {
        const unsigned int stride = elementNumInner[iDim];

        // Coordinate of the input element along this dimension
        const unsigned int inputCoord = remaining / stride;
        remaining %= stride;

        // A reversed dimension mirrors the coordinate; any other dimension keeps it
        const unsigned int outputCoord = axisFlag[iDim] ? (dimSize[iDim] - inputCoord - 1)
                                                        : inputCoord;

        // Accumulate the flattened output index
        outputIdx += outputCoord * stride;
    }

    return outputIdx;
}

// Reverse a tensor along the axes listed in params.m_Axis.
//
// params        - descriptor whose m_Axis holds the axes to reverse. A negative entry has
//                 the input rank added to it before use.
// inputInfo     - provides the rank, shape and element count of the input.
// inputDecoder  - read front to back (Get() followed by += 1), once per element.
// outputEncoder - receives each value at the flattened index computed by ReverseRelocateIdx.
//
// If m_Axis is empty or the tensor has no elements, the input is copied to the output
// in order and nothing else is done.
//
// Throws std::out_of_range if an axis, after the negative adjustment, does not lie in
// [0, rank). Previously such an axis wrote outside the bounds of the axis flag vector.
void ReverseV2(const ReverseV2Descriptor& params,
               const TensorInfo& inputInfo,
               Decoder<float>& inputDecoder,
               Encoder<float>& outputEncoder)
{
    const unsigned int elementNum = inputInfo.GetNumElements();

    // Empty axis and empty tensor case: copy input to output
    if (params.m_Axis.empty() || elementNum == 0)
    {
        for (unsigned int idx = 0; idx < elementNum; ++idx)
        {
            float inputValue = inputDecoder.Get();
            inputDecoder += 1;
            outputEncoder.Set(inputValue);
            outputEncoder += 1;
        }
        return;
    }

    unsigned int inputRank = static_cast<unsigned int>(inputInfo.GetNumDimensions());

    std::vector<bool> axisFlag(inputRank, false);
    std::vector<unsigned int> dimSize(inputRank, 0);

    // Make sure the axes are positive, then mark them as reversed
    for (int32_t axisElement : params.m_Axis)
    {
        axisElement = axisElement < 0 ? axisElement + static_cast<int32_t>(inputRank) : axisElement;
        // at() is bounds-checked: an axis that is still negative (it wraps to a huge unsigned
        // value) or is >= inputRank throws instead of corrupting memory.
        axisFlag.at(static_cast<uint32_t>(axisElement)) = true;
    }

    const TensorShape& inputShape = inputInfo.GetShape();

    unsigned int baseDimSize = 1;

    std::vector<unsigned int> elementNumInner;
    elementNumInner.reserve(inputRank);

    // For each dimension, record its extent and the number of elements covered by one
    // step along it (total element count divided by the product of extents up to and
    // including that dimension). No extent is zero here because elementNum != 0.
    for (unsigned int iDim = 0; iDim < inputRank; ++iDim)
    {
        dimSize[iDim] = inputShape[iDim];
        baseDimSize *= dimSize[iDim];
        elementNumInner.push_back(static_cast<unsigned int>(elementNum / baseDimSize));
    }

    // Iterate through all elements: read the input sequentially, write to the relocated index
    for (unsigned int idx = 0; idx < elementNum; ++idx)
    {
        float inputValue = inputDecoder.Get();
        inputDecoder += 1;
        auto outputIdx = ReverseRelocateIdx(idx, inputRank, axisFlag, dimSize, elementNumInner);
        // NOTE(review): operator[] is used only for its effect, presumably positioning the
        // encoder at outputIdx so that the following Set() writes there - confirm against Encoder.
        outputEncoder[outputIdx];
        outputEncoder.Set(inputValue);
    }
}

} // namespace armnn