aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Pooling2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/Pooling2d.cpp')
-rw-r--r--src/backends/reference/workloads/Pooling2d.cpp37
1 files changed, 22 insertions, 15 deletions
diff --git a/src/backends/reference/workloads/Pooling2d.cpp b/src/backends/reference/workloads/Pooling2d.cpp
index a9cac32ced..f2532cac03 100644
--- a/src/backends/reference/workloads/Pooling2d.cpp
+++ b/src/backends/reference/workloads/Pooling2d.cpp
@@ -4,7 +4,7 @@
//
#include "Pooling2d.hpp"
-#include "TensorBufferArrayView.hpp"
+#include "DataLayoutIndexed.hpp"
#include <armnn/Exceptions.hpp>
#include <armnn/Types.hpp>
@@ -139,14 +139,13 @@ using namespace armnnUtils;
namespace armnn
{
-
-void Pooling2d(const float* in,
- float* out,
+void Pooling2d(Decoder<float>& rInputDecoder,
+ Encoder<float>& rOutputEncoder,
const TensorInfo& inputInfo,
const TensorInfo& outputInfo,
const Pooling2dDescriptor& params)
{
- const DataLayoutIndexed dataLayout = params.m_DataLayout;
+ const DataLayoutIndexed dataLayout(params.m_DataLayout);
auto channelsIndex = dataLayout.GetChannelsIndex();
auto heightIndex = dataLayout.GetHeightIndex();
auto widthIndex = dataLayout.GetWidthIndex();
@@ -171,8 +170,8 @@ void Pooling2d(const float* in,
Accumulator accumulate = GetAccumulator(params.m_PoolType);
Executor execute = GetExecutor(params.m_PoolType);
- TensorBufferArrayView<const float> input(inputInfo.GetShape(), in, dataLayout);
- TensorBufferArrayView<float> output(outputInfo.GetShape(), out, dataLayout);
+ TensorShape outputShape = outputInfo.GetShape();
+ TensorShape inputShape = inputInfo.GetShape();
// Check supported padding methods outside the loop to simplify
// the inner loop.
@@ -228,10 +227,14 @@ void Pooling2d(const float* in,
{
for (auto xInput = wstart; xInput < wend; xInput++)
{
- float inval = input.Get(boost::numeric_cast<unsigned int>(n),
- boost::numeric_cast<unsigned int>(c),
- boost::numeric_cast<unsigned int>(yInput),
- boost::numeric_cast<unsigned int>(xInput));
+ unsigned int inputIndex = dataLayout.GetIndex(inputShape,
+ boost::numeric_cast<unsigned int>(n),
+ boost::numeric_cast<unsigned int>(c),
+ boost::numeric_cast<unsigned int>(yInput),
+ boost::numeric_cast<unsigned int>(xInput));
+
+ rInputDecoder[inputIndex];
+ float inval = rInputDecoder.Get();
accumulate(result, inval);
}
@@ -239,10 +242,14 @@ void Pooling2d(const float* in,
execute(result, poolAreaSize);
- output.Get(boost::numeric_cast<unsigned int>(n),
- boost::numeric_cast<unsigned int>(c),
- boost::numeric_cast<unsigned int>(yOutput),
- boost::numeric_cast<unsigned int>(xOutput)) = result;
+ unsigned int outputIndex = dataLayout.GetIndex(outputShape,
+ boost::numeric_cast<unsigned int>(n),
+ boost::numeric_cast<unsigned int>(c),
+ boost::numeric_cast<unsigned int>(yOutput),
+ boost::numeric_cast<unsigned int>(xOutput));
+
+ rOutputEncoder[outputIndex];
+ rOutputEncoder.Set(result);
}
}
}