diff options
Diffstat (limited to 'src/backends/reference/workloads/Resize.cpp')
-rw-r--r-- | src/backends/reference/workloads/Resize.cpp | 51 |
1 files changed, 43 insertions, 8 deletions
diff --git a/src/backends/reference/workloads/Resize.cpp b/src/backends/reference/workloads/Resize.cpp index 0e0bdd7597..3050bae870 100644 --- a/src/backends/reference/workloads/Resize.cpp +++ b/src/backends/reference/workloads/Resize.cpp @@ -25,6 +25,11 @@ inline float Lerp(float a, float b, float w) return w * b + (1.f - w) * a; } +inline double EuclideanDistance(float Xa, float Ya, const unsigned int Xb, const unsigned int Yb) +{ + return std::sqrt(pow(Xa - boost::numeric_cast<float>(Xb), 2) + pow(Ya - boost::numeric_cast<float>(Yb), 2)); +} + }// anonymous namespace void Resize(Decoder<float>& in, @@ -104,20 +109,50 @@ void Resize(Decoder<float>& in, break; } case armnn::ResizeMethod::NearestNeighbor: - default: { - auto distance0 = std::sqrt(pow(fix - boost::numeric_cast<float>(x0), 2) + - pow(fiy - boost::numeric_cast<float>(y0), 2)); - auto distance1 = std::sqrt(pow(fix - boost::numeric_cast<float>(x1), 2) + - pow(fiy - boost::numeric_cast<float>(y1), 2)); - - unsigned int xNearest = distance0 <= distance1? x0 : x1; - unsigned int yNearest = distance0 <= distance1? y0 : y1; + // calculate euclidean distance to the 4 neighbours + auto distance00 = EuclideanDistance(fix, fiy, x0, y0); + auto distance01 = EuclideanDistance(fix, fiy, x0, y1); + auto distance10 = EuclideanDistance(fix, fiy, x1, y0); + auto distance11 = EuclideanDistance(fix, fiy, x1, y1); + + auto minimum = std::min( { distance00, distance01, distance10, distance11 } ); + + unsigned int xNearest = 0; + unsigned int yNearest = 0; + + if (minimum == distance00) + { + xNearest = x0; + yNearest = y0; + } + else if (minimum == distance01) + { + xNearest = x0; + yNearest = y1; + } + else if (minimum == distance10) + { + xNearest = x1; + yNearest = y0; + } + else if (minimum == distance11) + { + xNearest = x1; + yNearest = y1; + } + else + { + throw armnn::InvalidArgumentException("Resize Nearest Neighbor failure"); + } in[dataLayout.GetIndex(inputShape, n, c, yNearest, xNearest)]; interpolatedValue = in.Get(); break; } + default: + throw armnn::InvalidArgumentException("Unknown resize method: " + + std::to_string(static_cast<int>(resizeMethod))); } out[dataLayout.GetIndex(outputShape, n, c, y, x)]; out.Set(interpolatedValue); |