ArmNN
 21.05
DataLayoutUtils.hpp File Reference
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
#include <armnnUtils/Permute.hpp>

Go to the source code of this file.

Functions

template<typename T >
void PermuteTensorNchwToNhwc (armnn::TensorInfo &tensorInfo, std::vector< T > &tensorData)
 
template<typename T >
void PermuteTensorNhwcToNchw (armnn::TensorInfo &tensorInfo, std::vector< T > &tensorData)
 

Function Documentation

◆ PermuteTensorNchwToNhwc()

template<typename T >
void PermuteTensorNchwToNhwc (armnn::TensorInfo &tensorInfo, std::vector< T > &tensorData)

Definition at line 14 of file DataLayoutUtils.hpp.

References TensorInfo::GetShape(), armnnUtils::Permute(), and armnnUtils::Permuted().

Referenced by TransposeConvolution2dPerAxisQuantTest().

15 {
16  const armnn::PermutationVector nchwToNhwc = { 0, 3, 1, 2 };
17 
18  tensorInfo = armnnUtils::Permuted(tensorInfo, nchwToNhwc);
19 
20  std::vector<T> tmp(tensorData.size());
21  armnnUtils::Permute(tensorInfo.GetShape(), nchwToNhwc, tensorData.data(), tmp.data(), sizeof(T));
22  tensorData = tmp;
23 }
Symbols referenced in this listing:
- const TensorShape & GetShape() const — Definition: Tensor.hpp:187
- void Permute(const armnn::TensorShape &dstShape, const armnn::PermutationVector &mappings, const void *src, void *dst, size_t dataTypeSize) — Definition: Permute.cpp:131
- armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings) — Definition: Permute.cpp:98

◆ PermuteTensorNhwcToNchw()

template<typename T >
void PermuteTensorNhwcToNchw (armnn::TensorInfo &tensorInfo, std::vector< T > &tensorData)

Definition at line 26 of file DataLayoutUtils.hpp.

References TensorInfo::GetShape(), armnnUtils::Permute(), and armnnUtils::Permuted().

Referenced by Convolution2dPerAxisQuantTest(), and DepthwiseConvolution2dPerAxisQuantTest().

27 {
28  const armnn::PermutationVector nhwcToNchw = { 0, 2, 3, 1 };
29 
30  tensorInfo = armnnUtils::Permuted(tensorInfo, nhwcToNchw);
31 
32  std::vector<T> tmp(tensorData.size());
33  armnnUtils::Permute(tensorInfo.GetShape(), nhwcToNchw, tensorData.data(), tmp.data(), sizeof(T));
34 
35  tensorData = tmp;
36 }
Symbols referenced in this listing:
- const TensorShape & GetShape() const — Definition: Tensor.hpp:187
- void Permute(const armnn::TensorShape &dstShape, const armnn::PermutationVector &mappings, const void *src, void *dst, size_t dataTypeSize) — Definition: Permute.cpp:131
- armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings) — Definition: Permute.cpp:98