aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/SpaceToBatchNd.cpp
blob: c3f022c6a6db36378330dc5b4e92349c27d1983b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
//
// Copyright © 2017-2019,2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "SpaceToBatchNd.hpp"

#include <armnnUtils/DataLayoutIndexed.hpp>

using namespace armnnUtils;

namespace armnn
{

unsigned int GetOffset(const TensorShape& shape,
                       unsigned int b,
                       unsigned int h,
                       unsigned int w,
                       unsigned int c,
                       const DataLayoutIndexed& dataLayout)
{
    // 3D Tensors
    unsigned int channelDimension3D = dataLayout.GetDataLayout() == DataLayout::NCHW ? 1 : 2;
    if (shape.GetNumDimensions() == 3)
    {
        return (b * shape[dataLayout.GetHeightIndex()] + h) * shape[channelDimension3D] + c;
    }
    // 4D Tensors
    else if (shape.GetNumDimensions() == 4)
    {
        if (dataLayout.GetDataLayout() == DataLayout::NHWC)
        {
            return ((b * shape[dataLayout.GetHeightIndex()] + h) * shape[dataLayout.GetWidthIndex()] + w) *
                   shape[dataLayout.GetChannelsIndex()] + c;
        }
        else
        {
            return ((b * shape[dataLayout.GetChannelsIndex()] + c) * shape[dataLayout.GetHeightIndex()] + h) *
                   shape[dataLayout.GetWidthIndex()] + w;
        }
    }
    else
    {
        throw InvalidArgumentException("Tensor rank must be either 3 or 4", CHECK_LOCATION());
    }
}

void SpaceToBatchNd(const TensorInfo& inputInfo,
                    const TensorInfo& outputInfo,
                    const SpaceToBatchNdDescriptor& params,
                    Decoder<float>& inputData,
                    Encoder<float>& outputData)
{
    unsigned int rank = inputInfo.GetNumDimensions();
    if (rank != 3 && rank != 4 )
    {
        throw InvalidArgumentException("Tensor rank must be either 3 or 4, but it is " + std::to_string(rank),
                                       CHECK_LOCATION());
    }

    DataLayoutIndexed dataLayout = params.m_DataLayout;
    unsigned int channelDimension3D = params.m_DataLayout == DataLayout::NCHW ? 1 : 2;

    const TensorShape& inputShape = inputInfo.GetShape();
    const TensorShape& outputShape = outputInfo.GetShape();

    const unsigned int inputBatchSize  = inputShape[0];
    const unsigned int outputBatchSize = outputShape[0];

    const unsigned int channels = (rank == 3) ? inputShape[channelDimension3D]
                                              : inputShape[dataLayout.GetChannelsIndex()];

    const unsigned int inputHeight  = inputShape[dataLayout.GetHeightIndex()];
    const unsigned int inputWidth   = (rank == 3) ? 1 : inputShape[dataLayout.GetWidthIndex()];
    const unsigned int outputHeight = outputShape[dataLayout.GetHeightIndex()];
    const unsigned int outputWidth  = (rank == 3) ? 1 : outputShape[dataLayout.GetWidthIndex()];

    const unsigned int blockHeight = params.m_BlockShape[0];
    const unsigned int blockWidth  = (rank == 3) ? 1 : params.m_BlockShape[1];

    const unsigned int paddingTop  = params.m_PadList[0].first;
    const unsigned int paddingLeft = (rank == 3) ? 0 : params.m_PadList[1].first;

    for (unsigned int outB = 0; outB < outputBatchSize; ++outB)
    {
        unsigned int inB = outB % inputBatchSize;

        unsigned int shiftW = (outB / inputBatchSize) % blockWidth;
        unsigned int shiftH = (outB / inputBatchSize) / blockWidth;

        for (unsigned int outH = 0; outH < outputHeight; ++outH)
        {
            for (unsigned int outW = 0; outW < outputWidth; ++outW)
            {
                if (outH * blockHeight + shiftH < paddingTop ||
                    outH * blockHeight + shiftH >= paddingTop + inputHeight ||
                    outW * blockWidth + shiftW < paddingLeft ||
                    outW * blockWidth + shiftW >= paddingLeft + inputWidth)
                {
                    for (unsigned int c = 0; c < channels; c++)
                    {
                        unsigned int outOffset = GetOffset(outputShape,
                                                           outB,
                                                           outH,
                                                           outW,
                                                           c,
                                                           dataLayout);
                        outputData += outOffset;
                        outputData.Set(0);
                        outputData -= outOffset;
                    }
                }
                else
                {
                    for (unsigned int c = 0; c < channels; c++)
                    {
                        unsigned int inOffset = GetOffset(inputShape,
                                                          inB,
                                                          (outH * blockHeight + shiftH) - paddingTop,
                                                          (outW * blockWidth + shiftW) - paddingLeft,
                                                          c,
                                                          dataLayout);

                        unsigned int outOffset = GetOffset(outputShape,
                                                           outB,
                                                           outH,
                                                           outW,
                                                           c,
                                                           dataLayout);

                        outputData += outOffset;
                        inputData += inOffset;
                        outputData.Set(inputData.Get());
                        inputData -= inOffset;
                        outputData -= outOffset;
                    }
                }
            }
        }
    }
}

} //namespace armnn