diff options
author | telsoa01 <telmo.soares@arm.com> | 2018-03-09 14:13:49 +0000 |
---|---|---|
committer | telsoa01 <telmo.soares@arm.com> | 2018-03-09 14:13:49 +0000 |
commit | 4fcda0101ec3d110c1d6d7bee5c83416b645528a (patch) | |
tree | c9a70aeb2887006160c1b3d265c27efadb7bdbae /src/armnnUtils/Permute.cpp | |
download | armnn-4fcda0101ec3d110c1d6d7bee5c83416b645528a.tar.gz |
Release 18.02
Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6
Diffstat (limited to 'src/armnnUtils/Permute.cpp')
-rw-r--r-- | src/armnnUtils/Permute.cpp | 118 |
1 files changed, 118 insertions, 0 deletions
diff --git a/src/armnnUtils/Permute.cpp b/src/armnnUtils/Permute.cpp new file mode 100644 index 0000000000..58e58583fc --- /dev/null +++ b/src/armnnUtils/Permute.cpp @@ -0,0 +1,118 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include "Permute.hpp" + +#include <armnn/Tensor.hpp> + +#include <cassert> + +namespace +{ + +class PermuteLoop +{ +public: + using size_type = unsigned int; + + PermuteLoop(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings) + : m_DstShape(dstShape) + { + assert(dstShape.GetNumDimensions() == mappings.GetSize()); + + const size_type numDims = dstShape.GetNumDimensions(); + + size_type srcStride = 1U; + size_type dstStride = 1U; + + for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i) + { + m_SrcStrides[mappings[i]] = srcStride; + m_DstStrides[i] = dstStride; + + srcStride *= dstShape[mappings[i]]; + dstStride *= dstShape[i]; + } + } + + template <typename T> + void Unroll(const T* srcData, T* dstData) + { + const T* const srcEnd = srcData + m_DstShape.GetNumElements(); + T* const dstEnd = dstData + m_DstShape.GetNumElements(); + Unroll(0, srcData, dstData, srcEnd, dstEnd); + } + +private: + template <typename T> + void Unroll(size_type dimension, const T* srcData, T* dstData, const T* srcEnd, T* dstEnd) + { + assert(srcData < srcEnd); + assert(dstData < dstEnd); + + if (dimension >= m_DstShape.GetNumDimensions()) + { + *dstData = *srcData; + } + else + { + for (size_type i = 0; i < m_DstShape[dimension]; i++) + { + Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd); + + srcData += m_SrcStrides[dimension]; + dstData += m_DstStrides[dimension]; + } + } + } + + armnn::TensorShape m_DstShape; + std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides; + std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides; +}; + +} // namespace + +namespace armnnUtils +{ + +armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings) +{ + assert(srcShape.GetNumDimensions() == mappings.GetSize()); + + const unsigned int numDims = mappings.GetSize(); + unsigned int outDims[armnn::MaxNumOfTensorDimensions]; + + for (unsigned int i = 0U; i < numDims; ++i) + { + outDims[mappings[i]] = srcShape[i]; + } + + armnn::TensorShape permutedShape(numDims, outDims); + return permutedShape; +} + +armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings) +{ + armnn::TensorInfo outInfo(info); + outInfo.SetShape(Permuted(info.GetShape(), mappings)); + return outInfo; +} + +template <typename T> +void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, const T* src, T* dst) +{ + PermuteLoop(dstShape, mappings).Unroll(src, dst); +} + +// Instantiate for types +template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, + const float* src, float* dst); +template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, + const uint8_t* src, uint8_t* dst); +template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, + const int32_t* src, int32_t* dst); + +} // namespace armnnUtils |