aboutsummaryrefslogtreecommitdiff
path: root/src/backends/aclCommon/ArmComputeUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/aclCommon/ArmComputeUtils.hpp')
-rw-r--r--src/backends/aclCommon/ArmComputeUtils.hpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/src/backends/aclCommon/ArmComputeUtils.hpp b/src/backends/aclCommon/ArmComputeUtils.hpp
index b4673f7b31..5b8f983ecc 100644
--- a/src/backends/aclCommon/ArmComputeUtils.hpp
+++ b/src/backends/aclCommon/ArmComputeUtils.hpp
@@ -9,6 +9,8 @@
#include <arm_compute/core/Types.h>
+#include <boost/assert.hpp>
+
namespace armnn
{
@@ -130,4 +132,23 @@ inline unsigned int ComputeSoftmaxAclAxis(const armnn::TensorInfo& tensor)
return dim - 1;
}
+inline std::set<unsigned int> ComputeSplitAxis(const armnn::SplitterDescriptor& desc, const TensorShape& input)
+{
+ unsigned int numSplit = desc.GetNumViews();
+ unsigned int numDimensions = desc.GetNumDimensions();
+ std::set<unsigned int> splitAxis;
+
+ for (unsigned int i = 0; i < numSplit; ++i)
+ {
+ for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
+ {
+ if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx])
+ {
+ splitAxis.insert(dimIdx);
+ }
+ }
+ }
+ return splitAxis;
+}
+
} // namespace armnn