diff options
Diffstat (limited to 'src/backends/reference/workloads/Pooling2d.cpp')
-rw-r--r-- | src/backends/reference/workloads/Pooling2d.cpp | 79 |
1 files changed, 55 insertions, 24 deletions
diff --git a/src/backends/reference/workloads/Pooling2d.cpp b/src/backends/reference/workloads/Pooling2d.cpp index 435671ffad..2bc3b4f213 100644 --- a/src/backends/reference/workloads/Pooling2d.cpp +++ b/src/backends/reference/workloads/Pooling2d.cpp @@ -172,9 +172,6 @@ void Pooling2d(Decoder<float>& rInputDecoder, Accumulator accumulate = GetAccumulator(params.m_PoolType); Executor execute = GetExecutor(params.m_PoolType); - TensorShape outputShape = outputInfo.GetShape(); - TensorShape inputShape = inputInfo.GetShape(); - // Check supported padding methods outside the loop to simplify // the inner loop. if (params.m_PaddingMethod != PaddingMethod::Exclude && @@ -183,6 +180,8 @@ void Pooling2d(Decoder<float>& rInputDecoder, throw armnn::InvalidArgumentException("Unsupported padding type"); } + const std::vector<float> decodedInputVec = rInputDecoder.DecodeTensor(inputInfo.GetNumElements()); + for (int n = 0; n < batchSize; n++) { for (int c = 0; c < channels; c++) @@ -221,12 +220,24 @@ void Pooling2d(Decoder<float>& rInputDecoder, { result = 0.0f; - unsigned int outputIndex = dataLayout.GetIndex(outputShape, - armnn::numeric_cast<unsigned int>(n), - armnn::numeric_cast<unsigned int>(c), - armnn::numeric_cast<unsigned int>(yOutput), - armnn::numeric_cast<unsigned int>(xOutput)); - rOutputEncoder[outputIndex]; + int outputIndex; + + if(dataLayout.GetDataLayout() == DataLayout::NHWC) + { + outputIndex = n * heightOutput * widthOutput * channels + + yOutput * widthOutput * channels + + xOutput * channels + + c; + } + else + { + outputIndex = n * heightOutput * widthOutput * channels + + c * heightOutput * widthOutput + + yOutput * widthOutput + + xOutput; + } + + rOutputEncoder[static_cast<unsigned int>(outputIndex)]; rOutputEncoder.Set(result); continue; } @@ -244,28 +255,48 @@ void Pooling2d(Decoder<float>& rInputDecoder, { for (auto xInput = wstart; xInput < wend; xInput++) { - unsigned int inputIndex = dataLayout.GetIndex(inputShape, - armnn::numeric_cast<unsigned int>(n), - armnn::numeric_cast<unsigned int>(c), - armnn::numeric_cast<unsigned int>(yInput), - armnn::numeric_cast<unsigned int>(xInput)); - - rInputDecoder[inputIndex]; - float inval = rInputDecoder.Get(); - accumulate(result, inval); + int inputIndex; + if(dataLayout.GetDataLayout() == DataLayout::NHWC) + { + inputIndex = n * heightInput * widthInput * channels + + yInput * widthInput * channels + + xInput * channels + + c; + + } + else + { + inputIndex = n * heightInput * widthInput * channels + + c * heightInput * widthInput + + yInput * widthInput + + xInput; + } + + accumulate(result, decodedInputVec[static_cast<unsigned int>(inputIndex)]); } } execute(result, poolAreaSize); - unsigned int outputIndex = dataLayout.GetIndex(outputShape, - armnn::numeric_cast<unsigned int>(n), - armnn::numeric_cast<unsigned int>(c), - armnn::numeric_cast<unsigned int>(yOutput), - armnn::numeric_cast<unsigned int>(xOutput)); + int outputIndex; + + if(dataLayout.GetDataLayout() == DataLayout::NHWC) + { + outputIndex = n * heightOutput * widthOutput * channels + + yOutput * widthOutput * channels + + xOutput * channels + + c; + } + else + { + outputIndex = n * heightOutput * widthOutput * channels + + c * heightOutput * widthOutput + + yOutput * widthOutput + + xOutput; + } - rOutputEncoder[outputIndex]; + rOutputEncoder[static_cast<unsigned int>(outputIndex)]; rOutputEncoder.Set(result); } } |