diff options
Diffstat (limited to 'src/backends/aclCommon')
-rw-r--r-- | src/backends/aclCommon/ArmComputeUtils.hpp | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/src/backends/aclCommon/ArmComputeUtils.hpp b/src/backends/aclCommon/ArmComputeUtils.hpp index b4673f7b31..5b8f983ecc 100644 --- a/src/backends/aclCommon/ArmComputeUtils.hpp +++ b/src/backends/aclCommon/ArmComputeUtils.hpp @@ -9,6 +9,8 @@ #include <arm_compute/core/Types.h> +#include <boost/assert.hpp> + namespace armnn { @@ -130,4 +132,23 @@ inline unsigned int ComputeSoftmaxAclAxis(const armnn::TensorInfo& tensor) return dim - 1; } +inline std::set<unsigned int> ComputeSplitAxis(const armnn::SplitterDescriptor& desc, const TensorShape& input) +{ + unsigned int numSplit = desc.GetNumViews(); + unsigned int numDimensions = desc.GetNumDimensions(); + std::set<unsigned int> splitAxis; + + for (unsigned int i = 0; i < numSplit; ++i) + { + for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx) + { + if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx]) + { + splitAxis.insert(dimIdx); + } + } + } + return splitAxis; +} + } // namespace armnn |