aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/TensorUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/TensorUtils.cpp')
-rw-r--r--src/armnnUtils/TensorUtils.cpp19
1 files changed, 19 insertions, 0 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp
index 2c25eec163..57f823fe13 100644
--- a/src/armnnUtils/TensorUtils.cpp
+++ b/src/armnnUtils/TensorUtils.cpp
@@ -27,5 +27,24 @@ armnn::TensorShape GetTensorShape(unsigned int numberOfBatches,
}
}
+armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches,
+ unsigned int numberOfChannels,
+ unsigned int height,
+ unsigned int width,
+ const armnn::DataLayout dataLayout,
+ const armnn::DataType dataType)
+{
+ switch (dataLayout)
+ {
+ case armnn::DataLayout::NCHW:
+ return armnn::TensorInfo({numberOfBatches, numberOfChannels, height, width}, dataType);
+ case armnn::DataLayout::NHWC:
+ return armnn::TensorInfo({numberOfBatches, height, width, numberOfChannels}, dataType);
+ default:
+ throw armnn::InvalidArgumentException("Unknown data layout ["
+ + std::to_string(static_cast<int>(dataLayout)) +
+ "]", CHECK_LOCATION());
+ }
}
+}