From d41b25938323455ea6b6d5348cab8861971b5fba Mon Sep 17 00:00:00 2001 From: Nina Drozd Date: Mon, 19 Nov 2018 13:03:36 +0000 Subject: IVGCVSW-2144: Adding TensorUtils class * helper methods for creating TensorShape and TensorInfo objects Change-Id: I371fc7aea08ca6bbb9c205a143ce36e8353a1c48 --- src/armnnUtils/TensorUtils.hpp | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 src/armnnUtils/TensorUtils.hpp (limited to 'src/armnnUtils/TensorUtils.hpp') diff --git a/src/armnnUtils/TensorUtils.hpp b/src/armnnUtils/TensorUtils.hpp new file mode 100644 index 0000000000..6461b37f75 --- /dev/null +++ b/src/armnnUtils/TensorUtils.hpp @@ -0,0 +1,37 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +namespace armnnUtils +{ +armnn::TensorShape GetTensorShape(unsigned int numberOfBatches, + unsigned int numberOfChannels, + unsigned int height, + unsigned int width, + const armnn::DataLayout dataLayout); + +template +armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches, + unsigned int numberOfChannels, + unsigned int height, + unsigned int width, + const armnn::DataLayout dataLayout) +{ + switch (dataLayout) + { + case armnn::DataLayout::NCHW: + return armnn::TensorInfo({numberOfBatches, numberOfChannels, height, width}, armnn::GetDataType()); + case armnn::DataLayout::NHWC: + return armnn::TensorInfo({numberOfBatches, height, width, numberOfChannels}, armnn::GetDataType()); + default: + throw armnn::InvalidArgumentException("Unknown data layout [" + + std::to_string(static_cast(dataLayout)) + + "]", CHECK_LOCATION()); + } +} +} // namespace armnnUtils \ No newline at end of file -- cgit v1.2.1