diff options
author | Éanna Ó Catháin <eanna.ocathain@arm.com> | 2018-11-12 11:36:34 +0000 |
---|---|---|
committer | Les Bell <les.bell@arm.com> | 2018-11-12 13:08:37 +0000 |
commit | 4e1e136cce3fca73ba49b570cfcb620f4ec574da (patch) | |
tree | 1fe9fcbb6a9dbafc12aa99ac543bc0da636a1cd1 /src/backends/reference/workloads/BatchToSpaceNd.cpp | |
parent | f97debb95cbc7e0bbc60e66e5463ede517cac61b (diff) | |
download | armnn-4e1e136cce3fca73ba49b570cfcb620f4ec574da.tar.gz |
IVGCVSW-2054: BATCH_TO_SPACE_ND Reference implementation and Unit tests.
Change-Id: I13c6728dbb60643d0e086d171225c5d802987f92
Diffstat (limited to 'src/backends/reference/workloads/BatchToSpaceNd.cpp')
-rw-r--r-- | src/backends/reference/workloads/BatchToSpaceNd.cpp | 100 |
1 files changed, 100 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/BatchToSpaceNd.cpp b/src/backends/reference/workloads/BatchToSpaceNd.cpp new file mode 100644 index 0000000000..bedf8418ef --- /dev/null +++ b/src/backends/reference/workloads/BatchToSpaceNd.cpp @@ -0,0 +1,100 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "BatchToSpaceNd.hpp" + +#include "RefWorkloadUtils.hpp" + +#include <armnn/Types.hpp> + +#include <boost/assert.hpp> + +namespace armnn +{ + +inline unsigned int Offset(const TensorShape& shape, unsigned int batch, unsigned int height, unsigned int width, + unsigned int channels, const DataLayoutIndexed& dataLayout) +{ + if (dataLayout.GetDataLayout() == DataLayout::NHWC) + { + return ((batch * shape[dataLayout.GetHeightIndex()] + height) * shape[dataLayout.GetWidthIndex()] + width) * + shape[dataLayout.GetChannelsIndex()] + channels; + } + else + { + return ((batch * shape[dataLayout.GetChannelsIndex()] + channels) * + shape[dataLayout.GetHeightIndex()] + height) * + shape[dataLayout.GetWidthIndex()] + width; + } +} + +void BatchToSpaceNd(const DataLayoutIndexed& dataLayout, + const TensorInfo& inputTensorInfo, + const TensorInfo& outputTensorInfo, + const std::vector<unsigned int>& blockShape, + const std::vector<std::vector<unsigned int>>& cropsData, + const float* inputData, + float* outputData) +{ + TensorShape inputShape = inputTensorInfo.GetShape(); + unsigned int inputNumDims = inputShape.GetNumDimensions(); + if (inputNumDims != 4) + { + throw armnn::InvalidArgumentException("Expected Input with 4 Dimensions"); + } + + TensorShape outputShape = outputTensorInfo.GetShape(); + unsigned int outputNumDims = outputShape.GetNumDimensions(); + if (outputNumDims != 4) + { + throw armnn::InvalidArgumentException("Expected Output with 4 Dimensions"); + } + + const unsigned int inputBatchSize = inputShape[0]; + const unsigned int channels = inputShape[dataLayout.GetChannelsIndex()]; + + const unsigned int outputBatchSize = outputShape[0]; + const unsigned int outputHeight = outputShape[dataLayout.GetHeightIndex()]; + const unsigned int outputWidth = outputShape[dataLayout.GetWidthIndex()]; + + const unsigned int blockShapeHeight = blockShape[0]; + const unsigned int blockShapeWidth = blockShape[1]; + + const unsigned int cropsTop = cropsData[0][0]; + const unsigned int cropsLeft = cropsData[1][0]; + + for (unsigned int inBatch = 0; inBatch < inputBatchSize; ++inBatch) + { + const unsigned int outBatch = inBatch % outputBatchSize; + const unsigned int spatialOffset = inBatch / outputBatchSize; + + for (unsigned int inH = 0; inH < inputTensorInfo.GetShape()[dataLayout.GetHeightIndex()]; ++inH) { + const unsigned int outH = inH * blockShapeHeight + spatialOffset / blockShapeWidth - cropsTop; + + if (outH >= outputHeight) + { + continue; + } + + for (unsigned int inW = 0; inW < inputTensorInfo.GetShape()[dataLayout.GetWidthIndex()]; ++inW) { + const unsigned int outW = inW * blockShapeWidth + spatialOffset % blockShapeWidth - cropsLeft; + + if (outW >= outputWidth) + { + continue; + } + + for (unsigned int c = 0; c < channels; c++) + { + unsigned int outOffset = Offset(outputShape, outBatch, outH, outW, c, dataLayout); + unsigned int inOffset = Offset(inputShape, inBatch, inH, inW, c, dataLayout); + outputData[outOffset] = inputData[inOffset]; + } + } + } + } +} + +} //namespace armnn |