From 7b885b3cce70154596b1994b013ea91527117c26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Ny=C3=ADri?= Date: Tue, 26 Oct 2021 14:47:57 +0100 Subject: IVGCVSW-6509 Front End + Reference Workload implementation Subtask of story: IVGCVSW-6164 Add a Pooling3d FrontEnd and Ref Implementation * Add front end * Add reference workload * Add corresponding unit tests Change-Id: Icce4146dd0a06a1da46a2def00a82d343e171750 Signed-off-by: Tamas Nyiri --- src/armnnUtils/TensorUtils.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'src/armnnUtils/TensorUtils.cpp') diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp index 505c9f8588..5b5b2bd6e6 100644 --- a/src/armnnUtils/TensorUtils.cpp +++ b/src/armnnUtils/TensorUtils.cpp @@ -55,6 +55,27 @@ TensorInfo GetTensorInfo(unsigned int numberOfBatches, } } +TensorInfo GetTensorInfo(unsigned int numberOfBatches, + unsigned int numberOfChannels, + unsigned int depth, + unsigned int height, + unsigned int width, + const DataLayout dataLayout, + const DataType dataType) +{ + switch (dataLayout) + { + case DataLayout::NDHWC: + return TensorInfo({numberOfBatches, depth, height, width, numberOfChannels}, dataType); + case DataLayout::NCDHW: + return TensorInfo({numberOfBatches, numberOfChannels, depth, height, width}, dataType); + default: + throw InvalidArgumentException("Unknown data layout [" + + std::to_string(static_cast(dataLayout)) + + "]", CHECK_LOCATION()); + } +} + std::pair FindMinMax(ITensorHandle* tensorHandle) { auto tensor_data = static_cast(tensorHandle->Map(true)); -- cgit v1.2.1