aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadUtils.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadUtils.cpp45
1 files changed, 45 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadUtils.cpp b/src/backends/backendsCommon/WorkloadUtils.cpp
index fcdad3e21b..d2ae16af0c 100644
--- a/src/backends/backendsCommon/WorkloadUtils.cpp
+++ b/src/backends/backendsCommon/WorkloadUtils.cpp
@@ -10,6 +10,7 @@
#include <armnnUtils/DataLayoutIndexed.hpp>
#include <fmt/format.h>
+#include <numeric>
namespace armnn
{
@@ -294,4 +295,48 @@ int32_t ConvertMaskToACLFormat(int32_t mask, int32_t numDim)
return reversedMask;
}
+std::map<std::string, unsigned int> CalculateGatherNdKeyIndices(TensorInfo inputInfo0, TensorInfo inputInfo1)
+{
+ std::vector<unsigned int> paramsShape;
+ for (unsigned int i = 0; i < inputInfo0.GetNumDimensions(); ++i)
+ {
+ paramsShape.push_back(inputInfo0.GetShape()[i]);
+ }
+
+ std::vector<unsigned int> indicesShape;
+ for (unsigned int i = 0; i < inputInfo1.GetNumDimensions(); ++i)
+ {
+ indicesShape.push_back(inputInfo1.GetShape()[i]);
+ }
+
+ std::map<std::string, unsigned int> keyIndices;
+
+ // N: number of batches
+ keyIndices["N"] = 1;
+
+ // ND: number of dimensions that are sliced from params
+ keyIndices["ND"] = indicesShape.back();
+
+ // W: number of indices in each batch (all but the last dimension)
+ keyIndices["W"] =
+ static_cast<unsigned int>(std::accumulate(std::begin(indicesShape),
+ std::end(indicesShape) - 1,
+ 1,
+ std::multiplies<>() ));
+ // K: range of each index
+ keyIndices["K"] =
+ static_cast<unsigned int>(std::accumulate(std::begin(paramsShape),
+ std::begin(paramsShape) + static_cast<int>(keyIndices["ND"]),
+ 1,
+ std::multiplies<>() ));
+ // C: number of channels for each index
+ keyIndices["C"] =
+ static_cast<unsigned int>(std::accumulate(std::begin(paramsShape) + static_cast<int>(keyIndices["ND"]),
+ std::end(paramsShape),
+ 1,
+ std::multiplies<>() ));
+
+ return keyIndices;
+}
+
} // namespace armnn