diff options
Diffstat (limited to 'src/armnn/test')
-rw-r--r-- | src/armnn/test/TensorHelpers.hpp | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/src/armnn/test/TensorHelpers.hpp b/src/armnn/test/TensorHelpers.hpp index 7f3ac9ec95..f1ab6c99b5 100644 --- a/src/armnn/test/TensorHelpers.hpp +++ b/src/armnn/test/TensorHelpers.hpp @@ -210,3 +210,22 @@ boost::multi_array<T, n> MakeRandomTensor(const armnn::TensorInfo& tensorInfo, int32_t qOffset = tensorInfo.GetQuantizationOffset(); return MakeTensor<T, n>(tensorInfo, QuantizedVector<T>(qScale, qOffset, init)); } + +template<typename T> +armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches, + unsigned int numberOfChannels, + unsigned int height, + unsigned int width, + const armnn::DataLayoutIndexed& dataLayout) +{ + switch (dataLayout.GetDataLayout()) + { + case armnn::DataLayout::NCHW: + return armnn::TensorInfo({numberOfBatches, numberOfChannels, height, width}, armnn::GetDataType<T>()); + case armnn::DataLayout::NHWC: + return armnn::TensorInfo({numberOfBatches, height, width, numberOfChannels}, armnn::GetDataType<T>()); + default: + throw armnn::InvalidArgumentException("unknown data layout [" + + std::to_string(static_cast<int>(dataLayout.GetDataLayout())) + "]"); + } +} |