aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Resize.cpp
blob: 7bed6c605627658fcc678bdccde09bfd02c3b3cf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
//
// Copyright © 2017, 2024 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "Resize.hpp"

#include "TensorBufferArrayView.hpp"

#include <armnn/utility/NumericCast.hpp>
#include <armnn/Utils.hpp>

#include <cmath>
#include <algorithm>

using namespace armnnUtils;

namespace
{

/// Linearly blends two values.
/// @param a Value returned when the weight is 0.
/// @param b Value returned when the weight is 1.
/// @param w Blend weight applied to b; (1 - w) is applied to a.
/// @return w * b + (1 - w) * a.
inline float Lerp(float a, float b, float w)
{
    const float weightOfA = 1.f - w;
    return w * b + weightOfA * a;
}

inline double EuclideanDistance(float Xa, float Ya, const unsigned int Xb, const unsigned int Yb)
{
    return std::sqrt(pow(Xa - armnn::numeric_cast<float>(Xb), 2) + pow(Ya - armnn::numeric_cast<float>(Yb), 2));
}

/// Ratio used to map an output coordinate along one axis back into the input image.
///
/// When AlignCorners is set and there is more than one output pixel, the first and last pixels of input and
/// output are made to coincide: (InputSize - 1) / (OutputSize - 1). Otherwise: InputSize / OutputSize.
/// NOTE(review): OutputSize == 0 produces a non-finite value; it is only meaningful when pixels are produced.
inline float CalculateResizeScale(unsigned int InputSize,
                                  unsigned int OutputSize,
                                  bool AlignCorners)
{
    return (AlignCorners && OutputSize > 1)
            ?  armnn::numeric_cast<float>(InputSize - 1) / armnn::numeric_cast<float>(OutputSize - 1)
            :  armnn::numeric_cast<float>(InputSize) / armnn::numeric_cast<float>(OutputSize);
}

inline float PixelScaler(const unsigned int& Pixel,
                         const float& Scale,
                         const bool& HalfPixelCenters,
                         armnn::ResizeMethod& resizeMethod)
{
    // For Half Pixel Centers the Top Left texel is assumed to be at 0.5,0.5
    if (HalfPixelCenters && resizeMethod == armnn::ResizeMethod::Bilinear)
    {
        return (static_cast<float>(Pixel) + 0.5f) * Scale - 0.5f;
    }
    // Nearest Neighbour doesn't need to have 0.5f trimmed off as it will floor the values later
    else if (HalfPixelCenters && resizeMethod == armnn::ResizeMethod::NearestNeighbor)
    {
        return (static_cast<float>(Pixel) + 0.5f) * Scale;
    }
    else
    {
        return static_cast<float>(Pixel) * Scale;
    }
}

}// anonymous namespace

namespace armnn
{
/// Reference implementation of a 2D spatial resize over a batched, multi-channel tensor.
///
/// For every output element the corresponding real-valued input coordinate is computed with PixelScaler.
/// The output value is either bilinearly interpolated from the surrounding 2x2 input texels or taken from
/// a single nearest texel.
///
/// @param in               Decoder over the input tensor data.
/// @param inputInfo        Input tensor info. Dimension 0 is read as the batch size; the channel, height and
///                         width dimensions are located through dataLayout.
/// @param out              Encoder over the output tensor data.
/// @param outputInfo       Output tensor info, indexed with the same dataLayout as the input.
/// @param dataLayout       Provides the channel/height/width dimension indices and flat element indices.
/// @param resizeMethod     ResizeMethod::Bilinear or ResizeMethod::NearestNeighbor.
/// @param alignCorners     Map the corner texels of the input and output onto each other.
/// @param halfPixelCenters Project texel centres (offset by 0.5) instead of top-left corners.
/// @throws InvalidArgumentException if alignCorners and halfPixelCenters are both set, if output elements
///         must be produced from an input with zero height or width, or if resizeMethod is not supported.
void Resize(Decoder<float>&   in,
            const TensorInfo& inputInfo,
            Encoder<float>&   out,
            const TensorInfo& outputInfo,
            DataLayoutIndexed dataLayout,
            ResizeMethod resizeMethod,
            bool alignCorners,
            bool halfPixelCenters)
{
    // alignCorners and halfPixelCenters cannot both be true
    ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(!(alignCorners && halfPixelCenters),
                                        "Resize: alignCorners and halfPixelCenters cannot both be true");

    // We follow the definition of TensorFlow and AndroidNN: by default the top-left corner of a texel in the
    // output image is projected into the input image to figure out the interpolants and weights. Note that this
    // will yield different results than if projecting the centre of output texels, which is what happens when
    // halfPixelCenters is set.

    const TensorShape& inputShape  = inputInfo.GetShape();
    const TensorShape& outputShape = outputInfo.GetShape();

    const unsigned int batchSize    = inputShape[0];
    const unsigned int channelCount = inputShape[dataLayout.GetChannelsIndex()];

    const unsigned int inputHeight  = inputShape[dataLayout.GetHeightIndex()];
    const unsigned int inputWidth   = inputShape[dataLayout.GetWidthIndex()];
    const unsigned int outputHeight = outputShape[dataLayout.GetHeightIndex()];
    const unsigned int outputWidth  = outputShape[dataLayout.GetWidthIndex()];

    // Every output element reads at least input texel (y0, x0). With a zero input height/width that read is
    // out of bounds, and the "size - 1u" clamps below would wrap around, so reject it up front.
    const bool producesOutput = batchSize > 0 && channelCount > 0 && outputHeight > 0 && outputWidth > 0;
    ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(!producesOutput || (inputHeight > 0 && inputWidth > 0),
                                        "Resize: input height and width must be non-zero");

    // How much to scale pixel coordinates in the output image, to get the corresponding pixel coordinates
    // in the input image.
    const float scaleY = CalculateResizeScale(inputHeight, outputHeight, alignCorners);
    const float scaleX = CalculateResizeScale(inputWidth, outputWidth, alignCorners);

    // Nearest Neighbour uses rounding to align to corners; everything else floors.
    const bool roundCoordinates = (resizeMethod == ResizeMethod::NearestNeighbor && alignCorners);

    for (unsigned int n = 0; n < batchSize; ++n)
    {
        for (unsigned int c = 0; c < channelCount; ++c)
        {
            for (unsigned int y = 0; y < outputHeight; ++y)
            {
                // Corresponding real-valued height coordinate in input image.
                const float iy = PixelScaler(y, scaleY, halfPixelCenters, resizeMethod);

                // Discrete height coordinate of top-left texel (in the 2x2 texel area used for interpolation).
                const float fiy = roundCoordinates ? armnn::roundf(iy) : floorf(iy);
                // Pixel scaling a value with Half Pixel Centers can be negative, if so clamp to 0
                const unsigned int y0 = static_cast<unsigned int>(std::max(fiy, 0.0f));

                // Interpolation weight (range [0,1]).
                const float yw = iy - fiy;

                // Row below y0 used for interpolation. It depends only on y, so compute it once per row.
                // Half Pixel Centers takes the ceiling of the projected coordinate, otherwise the next row is
                // used; both are clamped to the last valid row.
                const unsigned int y1 = halfPixelCenters
                                        ? std::min(static_cast<unsigned int>(std::ceil(iy)), inputHeight - 1u)
                                        : std::min(y0 + 1, inputHeight - 1u);

                for (unsigned int x = 0; x < outputWidth; ++x)
                {
                    // Real-valued and discrete width coordinates in input image.
                    const float ix = PixelScaler(x, scaleX, halfPixelCenters, resizeMethod);
                    const float fix = roundCoordinates ? armnn::roundf(ix) : floorf(ix);
                    // Pixel scaling a value with Half Pixel Centers can be negative, if so clamp to 0
                    const unsigned int x0 = static_cast<unsigned int>(std::max(fix, 0.0f));

                    // Interpolation weight (range [0,1]).
                    const float xw = ix - fix;

                    // Column to the right of x0, chosen the same way as y1.
                    const unsigned int x1 = halfPixelCenters
                                            ? std::min(static_cast<unsigned int>(std::ceil(ix)), inputWidth - 1u)
                                            : std::min(x0 + 1, inputWidth - 1u);

                    float interpolatedValue = 0.0f;
                    switch (resizeMethod)
                    {
                        case ResizeMethod::Bilinear:
                        {
                            in[dataLayout.GetIndex(inputShape, n, c, y0, x0)];
                            const float input1 = in.Get();
                            in[dataLayout.GetIndex(inputShape, n, c, y0, x1)];
                            const float input2 = in.Get();
                            in[dataLayout.GetIndex(inputShape, n, c, y1, x0)];
                            const float input3 = in.Get();
                            in[dataLayout.GetIndex(inputShape, n, c, y1, x1)];
                            const float input4 = in.Get();

                            const float ly0 = Lerp(input1, input2, xw); // lerp along row y0.
                            const float ly1 = Lerp(input3, input4, xw); // lerp along row y1.
                            interpolatedValue = Lerp(ly0, ly1, yw);
                            break;
                        }
                        case ResizeMethod::NearestNeighbor:
                        {
                            // calculate euclidean distance to the 4 neighbours
                            // NOTE(review): for NearestNeighbor the projected coordinates are never negative, so
                            // (x0, y0) equals (fix, fiy) and distance00 is 0; the first branch is always taken.
                            const auto distance00 = EuclideanDistance(fix, fiy, x0, y0);
                            const auto distance01 = EuclideanDistance(fix, fiy, x0, y1);
                            const auto distance10 = EuclideanDistance(fix, fiy, x1, y0);
                            const auto distance11 = EuclideanDistance(fix, fiy, x1, y1);

                            const auto minimum = std::min( { distance00, distance01, distance10, distance11 } );

                            unsigned int xNearest = 0;
                            unsigned int yNearest = 0;

                            if (minimum == distance00)
                            {
                                xNearest = x0;
                                yNearest = y0;
                            }
                            else if (minimum == distance01)
                            {
                                xNearest = x0;
                                yNearest = y1;
                            }
                            else if (minimum == distance10)
                            {
                                xNearest = x1;
                                yNearest = y0;
                            }
                            else if (minimum == distance11)
                            {
                                xNearest = x1;
                                yNearest = y1;
                            }
                            else
                            {
                                throw InvalidArgumentException("Resize Nearest Neighbor failure");
                            }

                            in[dataLayout.GetIndex(inputShape, n, c, yNearest, xNearest)];
                            interpolatedValue = in.Get();
                            break;
                        }
                        default:
                            throw InvalidArgumentException("Unknown resize method: " +
                                                            std::to_string(static_cast<int>(resizeMethod)));
                    }
                    out[dataLayout.GetIndex(outputShape, n, c, y, x)];
                    out.Set(interpolatedValue);
                }
            }
        }
    }
}

} //namespace armnn