aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/TensorUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/TensorUtils.cpp')
-rw-r--r--src/armnnUtils/TensorUtils.cpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp
index 505c9f8588..5b5b2bd6e6 100644
--- a/src/armnnUtils/TensorUtils.cpp
+++ b/src/armnnUtils/TensorUtils.cpp
@@ -55,6 +55,27 @@ TensorInfo GetTensorInfo(unsigned int numberOfBatches,
}
}
+TensorInfo GetTensorInfo(unsigned int numberOfBatches,
+ unsigned int numberOfChannels,
+ unsigned int depth,
+ unsigned int height,
+ unsigned int width,
+ const DataLayout dataLayout,
+ const DataType dataType)
+{
+ switch (dataLayout)
+ {
+ case DataLayout::NDHWC:
+ return TensorInfo({numberOfBatches, depth, height, width, numberOfChannels}, dataType);
+ case DataLayout::NCDHW:
+ return TensorInfo({numberOfBatches, numberOfChannels, depth, height, width}, dataType);
+ default:
+ throw InvalidArgumentException("Unknown data layout ["
+ + std::to_string(static_cast<int>(dataLayout)) +
+ "]", CHECK_LOCATION());
+ }
+}
+
std::pair<float, float> FindMinMax(ITensorHandle* tensorHandle)
{
auto tensor_data = static_cast<const float *>(tensorHandle->Map(true));