diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadUtils.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadUtils.cpp | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadUtils.cpp b/src/backends/backendsCommon/WorkloadUtils.cpp index fcdad3e21b..d2ae16af0c 100644 --- a/src/backends/backendsCommon/WorkloadUtils.cpp +++ b/src/backends/backendsCommon/WorkloadUtils.cpp @@ -10,6 +10,7 @@ #include <armnnUtils/DataLayoutIndexed.hpp> #include <fmt/format.h> +#include <numeric> namespace armnn { @@ -294,4 +295,48 @@ int32_t ConvertMaskToACLFormat(int32_t mask, int32_t numDim) return reversedMask; } +std::map<std::string, unsigned int> CalculateGatherNdKeyIndices(TensorInfo inputInfo0, TensorInfo inputInfo1) +{ + std::vector<unsigned int> paramsShape; + for (unsigned int i = 0; i < inputInfo0.GetNumDimensions(); ++i) + { + paramsShape.push_back(inputInfo0.GetShape()[i]); + } + + std::vector<unsigned int> indicesShape; + for (unsigned int i = 0; i < inputInfo1.GetNumDimensions(); ++i) + { + indicesShape.push_back(inputInfo1.GetShape()[i]); + } + + std::map<std::string, unsigned int> keyIndices; + + // N: number of batches + keyIndices["N"] = 1; + + // ND: number of dimensions that are sliced from params + keyIndices["ND"] = indicesShape.back(); + + // W: number of indices in each batch (all but the last dimension) + keyIndices["W"] = + static_cast<unsigned int>(std::accumulate(std::begin(indicesShape), + std::end(indicesShape) - 1, + 1, + std::multiplies<>() )); + // K: range of each index + keyIndices["K"] = + static_cast<unsigned int>(std::accumulate(std::begin(paramsShape), + std::begin(paramsShape) + static_cast<int>(keyIndices["ND"]), + 1, + std::multiplies<>() )); + // C: number of channels for each index + keyIndices["C"] = + static_cast<unsigned int>(std::accumulate(std::begin(paramsShape) + static_cast<int>(keyIndices["ND"]), + std::end(paramsShape), + 1, + std::multiplies<>() )); + + return keyIndices; +} + } // namespace armnn |