diff options
Diffstat (limited to 'src/backends/reference/workloads/SpaceToBatchNd.cpp')
-rw-r--r-- | src/backends/reference/workloads/SpaceToBatchNd.cpp | 65 |
1 files changed, 41 insertions, 24 deletions
diff --git a/src/backends/reference/workloads/SpaceToBatchNd.cpp b/src/backends/reference/workloads/SpaceToBatchNd.cpp index b6bab17367..c3f022c6a6 100644 --- a/src/backends/reference/workloads/SpaceToBatchNd.cpp +++ b/src/backends/reference/workloads/SpaceToBatchNd.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017-2019,2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -19,15 +19,29 @@ unsigned int GetOffset(const TensorShape& shape, unsigned int c, const DataLayoutIndexed& dataLayout) { - if (dataLayout.GetDataLayout() == DataLayout::NHWC) + // 3D Tensors + unsigned int channelDimension3D = dataLayout.GetDataLayout() == DataLayout::NCHW ? 1 : 2; + if (shape.GetNumDimensions() == 3) { - return ((b * shape[dataLayout.GetHeightIndex()] + h) * shape[dataLayout.GetWidthIndex()] + w) * - shape[dataLayout.GetChannelsIndex()] + c; + return (b * shape[dataLayout.GetHeightIndex()] + h) * shape[channelDimension3D] + c; + } + // 4D Tensors + else if (shape.GetNumDimensions() == 4) + { + if (dataLayout.GetDataLayout() == DataLayout::NHWC) + { + return ((b * shape[dataLayout.GetHeightIndex()] + h) * shape[dataLayout.GetWidthIndex()] + w) * + shape[dataLayout.GetChannelsIndex()] + c; + } + else + { + return ((b * shape[dataLayout.GetChannelsIndex()] + c) * shape[dataLayout.GetHeightIndex()] + h) * + shape[dataLayout.GetWidthIndex()] + w; + } } else { - return ((b * shape[dataLayout.GetChannelsIndex()] + c) * shape[dataLayout.GetHeightIndex()] + h) * - shape[dataLayout.GetWidthIndex()] + w; + throw InvalidArgumentException("Tensor rank must be either 3 or 4", CHECK_LOCATION()); } } @@ -37,37 +51,46 @@ void SpaceToBatchNd(const TensorInfo& inputInfo, Decoder<float>& inputData, Encoder<float>& outputData) { + unsigned int rank = inputInfo.GetNumDimensions(); + if (rank != 3 && rank != 4 ) + { + throw InvalidArgumentException("Tensor rank must be either 3 or 4, but it is " + std::to_string(rank), + CHECK_LOCATION()); + } + DataLayoutIndexed dataLayout = params.m_DataLayout; + unsigned int channelDimension3D = params.m_DataLayout == DataLayout::NCHW ? 1 : 2; const TensorShape& inputShape = inputInfo.GetShape(); const TensorShape& outputShape = outputInfo.GetShape(); - const unsigned int channels = inputShape[dataLayout.GetChannelsIndex()]; + const unsigned int inputBatchSize = inputShape[0]; + const unsigned int outputBatchSize = outputShape[0]; - const unsigned int inputBatchSize = inputShape[0]; - const unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()]; - const unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()]; + const unsigned int channels = (rank == 3) ? inputShape[channelDimension3D] + : inputShape[dataLayout.GetChannelsIndex()]; - const unsigned int outputBatchSize = outputShape[0]; + const unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()]; + const unsigned int inputWidth = (rank == 3) ? 1 : inputShape[dataLayout.GetWidthIndex()]; const unsigned int outputHeight = outputShape[dataLayout.GetHeightIndex()]; - const unsigned int outputWidth = outputShape[dataLayout.GetWidthIndex()]; + const unsigned int outputWidth = (rank == 3) ? 1 : outputShape[dataLayout.GetWidthIndex()]; const unsigned int blockHeight = params.m_BlockShape[0]; - const unsigned int blockWidth = params.m_BlockShape[1]; + const unsigned int blockWidth = (rank == 3) ? 1 : params.m_BlockShape[1]; - const unsigned int paddingTop = params.m_PadList[0].first; - const unsigned int paddingLeft = params.m_PadList[1].first; + const unsigned int paddingTop = params.m_PadList[0].first; + const unsigned int paddingLeft = (rank == 3) ? 0 : params.m_PadList[1].first; - for (unsigned int outB = 0; outB < outputBatchSize; outB++) + for (unsigned int outB = 0; outB < outputBatchSize; ++outB) { unsigned int inB = outB % inputBatchSize; unsigned int shiftW = (outB / inputBatchSize) % blockWidth; unsigned int shiftH = (outB / inputBatchSize) / blockWidth; - for (unsigned int outH = 0; outH < outputHeight; outH++) + for (unsigned int outH = 0; outH < outputHeight; ++outH) { - for (unsigned int outW = 0; outW < outputWidth; outW++) + for (unsigned int outW = 0; outW < outputWidth; ++outW) { if (outH * blockHeight + shiftH < paddingTop || outH * blockHeight + shiftH >= paddingTop + inputHeight || @@ -117,10 +140,4 @@ void SpaceToBatchNd(const TensorInfo& inputInfo, } } -void SpaceToBatchNd(const TensorInfo& inputInfo, - const TensorInfo& outputInfo, - const SpaceToBatchNdDescriptor& params, - Decoder<float>& inputData, - Encoder<float>& outData); - } //namespace armnn |