aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Pooling2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/Pooling2d.cpp')
-rw-r--r--src/backends/reference/workloads/Pooling2d.cpp79
1 files changed, 55 insertions, 24 deletions
diff --git a/src/backends/reference/workloads/Pooling2d.cpp b/src/backends/reference/workloads/Pooling2d.cpp
index 435671ffad..2bc3b4f213 100644
--- a/src/backends/reference/workloads/Pooling2d.cpp
+++ b/src/backends/reference/workloads/Pooling2d.cpp
@@ -172,9 +172,6 @@ void Pooling2d(Decoder<float>& rInputDecoder,
Accumulator accumulate = GetAccumulator(params.m_PoolType);
Executor execute = GetExecutor(params.m_PoolType);
- TensorShape outputShape = outputInfo.GetShape();
- TensorShape inputShape = inputInfo.GetShape();
-
// Check supported padding methods outside the loop to simplify
// the inner loop.
if (params.m_PaddingMethod != PaddingMethod::Exclude &&
@@ -183,6 +180,8 @@ void Pooling2d(Decoder<float>& rInputDecoder,
throw armnn::InvalidArgumentException("Unsupported padding type");
}
+ const std::vector<float> decodedInputVec = rInputDecoder.DecodeTensor(inputInfo.GetNumElements());
+
for (int n = 0; n < batchSize; n++)
{
for (int c = 0; c < channels; c++)
@@ -221,12 +220,24 @@ void Pooling2d(Decoder<float>& rInputDecoder,
{
result = 0.0f;
- unsigned int outputIndex = dataLayout.GetIndex(outputShape,
- armnn::numeric_cast<unsigned int>(n),
- armnn::numeric_cast<unsigned int>(c),
- armnn::numeric_cast<unsigned int>(yOutput),
- armnn::numeric_cast<unsigned int>(xOutput));
- rOutputEncoder[outputIndex];
+ int outputIndex;
+
+ if(dataLayout.GetDataLayout() == DataLayout::NHWC)
+ {
+ outputIndex = n * heightOutput * widthOutput * channels +
+ yOutput * widthOutput * channels +
+ xOutput * channels +
+ c;
+ }
+ else
+ {
+ outputIndex = n * heightOutput * widthOutput * channels +
+ c * heightOutput * widthOutput +
+ yOutput * widthOutput +
+ xOutput;
+ }
+
+ rOutputEncoder[static_cast<unsigned int>(outputIndex)];
rOutputEncoder.Set(result);
continue;
}
@@ -244,28 +255,48 @@ void Pooling2d(Decoder<float>& rInputDecoder,
{
for (auto xInput = wstart; xInput < wend; xInput++)
{
- unsigned int inputIndex = dataLayout.GetIndex(inputShape,
- armnn::numeric_cast<unsigned int>(n),
- armnn::numeric_cast<unsigned int>(c),
- armnn::numeric_cast<unsigned int>(yInput),
- armnn::numeric_cast<unsigned int>(xInput));
-
- rInputDecoder[inputIndex];
- float inval = rInputDecoder.Get();
- accumulate(result, inval);
+ int inputIndex;
+ if(dataLayout.GetDataLayout() == DataLayout::NHWC)
+ {
+ inputIndex = n * heightInput * widthInput * channels +
+ yInput * widthInput * channels +
+ xInput * channels +
+ c;
+
+ }
+ else
+ {
+ inputIndex = n * heightInput * widthInput * channels +
+ c * heightInput * widthInput +
+ yInput * widthInput +
+ xInput;
+ }
+
+ accumulate(result, decodedInputVec[static_cast<unsigned int>(inputIndex)]);
}
}
execute(result, poolAreaSize);
- unsigned int outputIndex = dataLayout.GetIndex(outputShape,
- armnn::numeric_cast<unsigned int>(n),
- armnn::numeric_cast<unsigned int>(c),
- armnn::numeric_cast<unsigned int>(yOutput),
- armnn::numeric_cast<unsigned int>(xOutput));
+ int outputIndex;
+
+ if(dataLayout.GetDataLayout() == DataLayout::NHWC)
+ {
+ outputIndex = n * heightOutput * widthOutput * channels +
+ yOutput * widthOutput * channels +
+ xOutput * channels +
+ c;
+ }
+ else
+ {
+ outputIndex = n * heightOutput * widthOutput * channels +
+ c * heightOutput * widthOutput +
+ yOutput * widthOutput +
+ xOutput;
+ }
- rOutputEncoder[outputIndex];
+ rOutputEncoder[static_cast<unsigned int>(outputIndex)];
rOutputEncoder.Set(result);
}
}