From b9dcfe63b87f024c6f8c5f4b68447de04119dc19 Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Thu, 17 Sep 2020 15:58:31 +0100 Subject: IVGCVSW-5325 Speed up the reference backend Change-Id: Id8bd0a0418be31d975b944b54bbacb25051ffb2e Signed-off-by: Finn Williams --- src/backends/reference/workloads/Pooling2d.cpp | 79 ++++++++++++++++++-------- 1 file changed, 55 insertions(+), 24 deletions(-) (limited to 'src/backends/reference/workloads/Pooling2d.cpp') diff --git a/src/backends/reference/workloads/Pooling2d.cpp b/src/backends/reference/workloads/Pooling2d.cpp index 435671ffad..2bc3b4f213 100644 --- a/src/backends/reference/workloads/Pooling2d.cpp +++ b/src/backends/reference/workloads/Pooling2d.cpp @@ -172,9 +172,6 @@ void Pooling2d(Decoder& rInputDecoder, Accumulator accumulate = GetAccumulator(params.m_PoolType); Executor execute = GetExecutor(params.m_PoolType); - TensorShape outputShape = outputInfo.GetShape(); - TensorShape inputShape = inputInfo.GetShape(); - // Check supported padding methods outside the loop to simplify // the inner loop. 
if (params.m_PaddingMethod != PaddingMethod::Exclude && @@ -183,6 +180,8 @@ void Pooling2d(Decoder& rInputDecoder, throw armnn::InvalidArgumentException("Unsupported padding type"); } + const std::vector<float> decodedInputVec = rInputDecoder.DecodeTensor(inputInfo.GetNumElements()); + for (int n = 0; n < batchSize; n++) { for (int c = 0; c < channels; c++) @@ -221,12 +220,24 @@ void Pooling2d(Decoder& rInputDecoder, { result = 0.0f; - unsigned int outputIndex = dataLayout.GetIndex(outputShape, - armnn::numeric_cast<unsigned int>(n), - armnn::numeric_cast<unsigned int>(c), - armnn::numeric_cast<unsigned int>(yOutput), - armnn::numeric_cast<unsigned int>(xOutput)); - rOutputEncoder[outputIndex]; + int outputIndex; + + if(dataLayout.GetDataLayout() == DataLayout::NHWC) + { + outputIndex = n * heightOutput * widthOutput * channels + + yOutput * widthOutput * channels + + xOutput * channels + + c; + } + else + { + outputIndex = n * heightOutput * widthOutput * channels + + c * heightOutput * widthOutput + + yOutput * widthOutput + + xOutput; + } + + rOutputEncoder[static_cast<unsigned int>(outputIndex)]; rOutputEncoder.Set(result); continue; } @@ -244,28 +255,48 @@ void Pooling2d(Decoder& rInputDecoder, { for (auto xInput = wstart; xInput < wend; xInput++) { - unsigned int inputIndex = dataLayout.GetIndex(inputShape, - armnn::numeric_cast<unsigned int>(n), - armnn::numeric_cast<unsigned int>(c), - armnn::numeric_cast<unsigned int>(yInput), - armnn::numeric_cast<unsigned int>(xInput)); - - rInputDecoder[inputIndex]; - float inval = rInputDecoder.Get(); - accumulate(result, inval); + int inputIndex; + if(dataLayout.GetDataLayout() == DataLayout::NHWC) + { + inputIndex = n * heightInput * widthInput * channels + + yInput * widthInput * channels + + xInput * channels + + c; + + } + else + { + inputIndex = n * heightInput * widthInput * channels + + c * heightInput * widthInput + + yInput * widthInput + + xInput; + } + + accumulate(result, decodedInputVec[static_cast<unsigned int>(inputIndex)]); } } execute(result, poolAreaSize); - unsigned int outputIndex = dataLayout.GetIndex(outputShape, - armnn::numeric_cast<unsigned int>(n), 
armnn::numeric_cast<unsigned int>(c), - armnn::numeric_cast<unsigned int>(yOutput), - armnn::numeric_cast<unsigned int>(xOutput)); + int outputIndex; + + if(dataLayout.GetDataLayout() == DataLayout::NHWC) + { + outputIndex = n * heightOutput * widthOutput * channels + + yOutput * widthOutput * channels + + xOutput * channels + + c; + } + else + { + outputIndex = n * heightOutput * widthOutput * channels + + c * heightOutput * widthOutput + + yOutput * widthOutput + + xOutput; + } - rOutputEncoder[outputIndex]; + rOutputEncoder[static_cast<unsigned int>(outputIndex)]; rOutputEncoder.Set(result); } } -- cgit v1.2.1