diff options
Diffstat (limited to 'src/backends/reference/workloads/Pooling2d.cpp')
-rw-r--r-- | src/backends/reference/workloads/Pooling2d.cpp | 37 |
1 files changed, 22 insertions, 15 deletions
diff --git a/src/backends/reference/workloads/Pooling2d.cpp b/src/backends/reference/workloads/Pooling2d.cpp index a9cac32ced..f2532cac03 100644 --- a/src/backends/reference/workloads/Pooling2d.cpp +++ b/src/backends/reference/workloads/Pooling2d.cpp @@ -4,7 +4,7 @@ // #include "Pooling2d.hpp" -#include "TensorBufferArrayView.hpp" +#include "DataLayoutIndexed.hpp" #include <armnn/Exceptions.hpp> #include <armnn/Types.hpp> @@ -139,14 +139,13 @@ using namespace armnnUtils; namespace armnn { - -void Pooling2d(const float* in, - float* out, +void Pooling2d(Decoder<float>& rInputDecoder, + Encoder<float>& rOutputEncoder, const TensorInfo& inputInfo, const TensorInfo& outputInfo, const Pooling2dDescriptor& params) { - const DataLayoutIndexed dataLayout = params.m_DataLayout; + const DataLayoutIndexed dataLayout(params.m_DataLayout); auto channelsIndex = dataLayout.GetChannelsIndex(); auto heightIndex = dataLayout.GetHeightIndex(); auto widthIndex = dataLayout.GetWidthIndex(); @@ -171,8 +170,8 @@ void Pooling2d(const float* in, Accumulator accumulate = GetAccumulator(params.m_PoolType); Executor execute = GetExecutor(params.m_PoolType); - TensorBufferArrayView<const float> input(inputInfo.GetShape(), in, dataLayout); - TensorBufferArrayView<float> output(outputInfo.GetShape(), out, dataLayout); + TensorShape outputShape = outputInfo.GetShape(); + TensorShape inputShape = inputInfo.GetShape(); // Check supported padding methods outside the loop to simplify // the inner loop. @@ -228,10 +227,14 @@ void Pooling2d(const float* in, { for (auto xInput = wstart; xInput < wend; xInput++) { - float inval = input.Get(boost::numeric_cast<unsigned int>(n), - boost::numeric_cast<unsigned int>(c), - boost::numeric_cast<unsigned int>(yInput), - boost::numeric_cast<unsigned int>(xInput)); + unsigned int inputIndex = dataLayout.GetIndex(inputShape, + boost::numeric_cast<unsigned int>(n), + boost::numeric_cast<unsigned int>(c), + boost::numeric_cast<unsigned int>(yInput), + boost::numeric_cast<unsigned int>(xInput)); + + rInputDecoder[inputIndex]; + float inval = rInputDecoder.Get(); accumulate(result, inval); } @@ -239,10 +242,14 @@ void Pooling2d(const float* in, execute(result, poolAreaSize); - output.Get(boost::numeric_cast<unsigned int>(n), - boost::numeric_cast<unsigned int>(c), - boost::numeric_cast<unsigned int>(yOutput), - boost::numeric_cast<unsigned int>(xOutput)) = result; + unsigned int outputIndex = dataLayout.GetIndex(outputShape, + boost::numeric_cast<unsigned int>(n), + boost::numeric_cast<unsigned int>(c), + boost::numeric_cast<unsigned int>(yOutput), + boost::numeric_cast<unsigned int>(xOutput)); + + rOutputEncoder[outputIndex]; + rOutputEncoder.Set(result); } } } |