aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/SplitterLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/SplitterLayer.cpp')
-rw-r--r--src/armnn/layers/SplitterLayer.cpp21
1 files changed, 1 insertions, 20 deletions
diff --git a/src/armnn/layers/SplitterLayer.cpp b/src/armnn/layers/SplitterLayer.cpp
index 8a24e0df1f..b04614b31b 100644
--- a/src/armnn/layers/SplitterLayer.cpp
+++ b/src/armnn/layers/SplitterLayer.cpp
@@ -9,6 +9,7 @@
#include <armnn/TypesUtils.hpp>
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/backends/WorkloadFactory.hpp>
+#include <backendsCommon/WorkloadUtils.hpp>
namespace armnn
{
@@ -57,26 +58,6 @@ void SplitterLayer::CreateTensors(const TensorHandleFactoryRegistry& registry,
// check if split is along the x or y (2 innermost dimensions)
auto numberOfDimensions = m_Param.GetNumDimensions();
- // Compute split axis within class as aclCommon function causes header issues when included
- auto ComputeSplitAxis = [&](const armnn::SplitterDescriptor& desc, const TensorShape& input)
- {
- unsigned int numSplit = desc.GetNumViews();
- unsigned int numDimensions = desc.GetNumDimensions();
- std::set<unsigned int> splitAxis;
-
- for (unsigned int i = 0; i < numSplit; ++i)
- {
- for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
- {
- if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx])
- {
- splitAxis.insert(dimIdx);
- }
- }
- }
- return splitAxis;
- };
-
std::set<unsigned int> axis = ComputeSplitAxis(m_Param, parentInfo.GetShape());
std::set<unsigned int>::iterator axisIt = axis.begin();