aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/ResizeBilinear.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/ResizeBilinear.cpp')
-rw-r--r--src/backends/reference/workloads/ResizeBilinear.cpp24
1 files changed, 17 insertions, 7 deletions
diff --git a/src/backends/reference/workloads/ResizeBilinear.cpp b/src/backends/reference/workloads/ResizeBilinear.cpp
index 2d1087c9a0..70a051492a 100644
--- a/src/backends/reference/workloads/ResizeBilinear.cpp
+++ b/src/backends/reference/workloads/ResizeBilinear.cpp
@@ -27,9 +27,9 @@ inline float Lerp(float a, float b, float w)
}
-void ResizeBilinear(const float* in,
+void ResizeBilinear(Decoder<float>& in,
const TensorInfo& inputInfo,
- float* out,
+ Encoder<float>& out,
const TensorInfo& outputInfo,
DataLayoutIndexed dataLayout)
{
@@ -50,8 +50,8 @@ void ResizeBilinear(const float* in,
const float scaleY = boost::numeric_cast<float>(inputHeight) / boost::numeric_cast<float>(outputHeight);
const float scaleX = boost::numeric_cast<float>(inputWidth) / boost::numeric_cast<float>(outputWidth);
- TensorBufferArrayView<const float> input(inputInfo.GetShape(), in, dataLayout);
- TensorBufferArrayView<float> output(outputInfo.GetShape(), out, dataLayout);
+ TensorShape inputShape = inputInfo.GetShape();
+ TensorShape outputShape = outputInfo.GetShape();
for (unsigned int n = 0; n < batchSize; ++n)
{
@@ -84,11 +84,21 @@ void ResizeBilinear(const float* in,
const unsigned int y1 = std::min(y0 + 1, inputHeight - 1u);
// Interpolation
- const float ly0 = Lerp(input.Get(n, c, y0, x0), input.Get(n, c, y0, x1), xw); // lerp along row y0.
- const float ly1 = Lerp(input.Get(n, c, y1, x0), input.Get(n, c, y1, x1), xw); // lerp along row y1.
+ in[dataLayout.GetIndex(inputShape, n, c, y0, x0)];
+ float input1 = in.Get();
+ in[dataLayout.GetIndex(inputShape, n, c, y0, x1)];
+ float input2 = in.Get();
+ in[dataLayout.GetIndex(inputShape, n, c, y1, x0)];
+ float input3 = in.Get();
+ in[dataLayout.GetIndex(inputShape, n, c, y1, x1)];
+ float input4 = in.Get();
+
+ const float ly0 = Lerp(input1, input2, xw); // lerp along row y0.
+ const float ly1 = Lerp(input3, input4, xw); // lerp along row y1.
const float l = Lerp(ly0, ly1, yw);
- output.Get(n, c, y, x) = l;
+ out[dataLayout.GetIndex(outputShape, n, c, y, x)];
+ out.Set(l);
}
}
}