aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2024-04-23 13:43:03 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2024-04-24 12:34:43 +0000
commit7db70891f2577fa0f65f4088f9587e69463b43e5 (patch)
tree8dfe6673056b54a1272b90196c0477aa611dd1f6 /src/armnn
parenta42e006dbaface834dae7a5f182d67789fb4daf5 (diff)
downloadarmnn-7db70891f2577fa0f65f4088f9587e69463b43e5.tar.gz
IVGCVSW-8602 Move ComputeSplitAxis() to backendsCommon/WorkloadUtils
* Use ComputeSplitAxis in SplitOperator in tosaCommon mappings * Fix TosaRef split tests, that were missing outputInfos Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Signed-off-by: Cathal Corbett <cathal.corbett@arm.com> Change-Id: Ib577eacdc6399242f37d25494e208aa56db6334c
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/layers/SplitterLayer.cpp21
1 files changed, 1 insertions, 20 deletions
diff --git a/src/armnn/layers/SplitterLayer.cpp b/src/armnn/layers/SplitterLayer.cpp
index 8a24e0df1f..b04614b31b 100644
--- a/src/armnn/layers/SplitterLayer.cpp
+++ b/src/armnn/layers/SplitterLayer.cpp
@@ -9,6 +9,7 @@
#include <armnn/TypesUtils.hpp>
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/backends/WorkloadFactory.hpp>
+#include <backendsCommon/WorkloadUtils.hpp>
namespace armnn
{
@@ -57,26 +58,6 @@ void SplitterLayer::CreateTensors(const TensorHandleFactoryRegistry& registry,
// check if split is along the x or y (2 innermost dimensions)
auto numberOfDimensions = m_Param.GetNumDimensions();
- // Compute split axis within class as aclCommon function causes header issues when included
- auto ComputeSplitAxis = [&](const armnn::SplitterDescriptor& desc, const TensorShape& input)
- {
- unsigned int numSplit = desc.GetNumViews();
- unsigned int numDimensions = desc.GetNumDimensions();
- std::set<unsigned int> splitAxis;
-
- for (unsigned int i = 0; i < numSplit; ++i)
- {
- for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
- {
- if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx])
- {
- splitAxis.insert(dimIdx);
- }
- }
- }
- return splitAxis;
- };
-
std::set<unsigned int> axis = ComputeSplitAxis(m_Param, parentInfo.GetShape());
std::set<unsigned int>::iterator axisIt = axis.begin();