diff options
Diffstat (limited to 'src/backends/aclCommon/ArmComputeTensorUtils.cpp')
-rw-r--r-- | src/backends/aclCommon/ArmComputeTensorUtils.cpp | 51 |
1 files changed, 34 insertions, 17 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp index 517b11ced8..38c7f70da5 100644 --- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp +++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp @@ -1,11 +1,12 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include <aclCommon/ArmComputeTensorUtils.hpp> #include <aclCommon/ArmComputeUtils.hpp> #include "armnn/Exceptions.hpp" +#include "ArmComputeUtils.hpp" #include <armnn/Descriptors.hpp> #include <fmt/format.h> @@ -293,28 +294,44 @@ arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::Per arm_compute::PermutationVector BuildArmComputeTransposeVector(const armnn::PermutationVector& perm) { - arm_compute::PermutationVector aclPerm; - std::map<unsigned int, unsigned int> permuteMappings; - for (unsigned int i = 0; i < perm.GetSize(); ++i) - { - permuteMappings[perm[i]] = i; - } + // As ArmNN indexes are left to right and ACL indexes are right to left, + // the permutation vector has to be reversed and then translated into ACL axis. + // i.e. 
{1, 0, 2, 3} --> {3, 2, 0, 1} --> {0, 1, 3, 2} + + // Below is an example of how the ArmNN and ACL index formats work: + // ArmNN Format: + // Input Shape {1, 10, 20, 30} + // Permutation Vector {1, 0, 2, 3} + // Output Shape {10, 1, 20, 30} + // dim "1" of input goes into index 0 of the output ([ 10, X, X, X]) + // dim "0" of input goes into index 1 of the output ([ 10, 1, X, X ]) + // dim "2" of input goes into index 2 of the output ([ 10, 1, 20, X ]) + // dim "3" of input goes into index 3 of the output ([ 10, 1, 20, 30 ]) + // ACL Format: + // Input Shape {30, 20, 10, 1} + // Permutation Vector {0, 1, 3, 2} + // Output Shape {30, 20, 1, 10} + // dim "0" of input goes into index 0 of the output ([ 30, X, X, X]) + // dim "1" of input goes into index 1 of the output ([ 30, 20, X, X ]) + // dim "3" of input goes into index 2 of the output ([ 30, 20, 1, X ]) + // dim "2" of input goes into index 3 of the output ([ 30, 20, 1, 10 ]) - std::vector<unsigned int> permuteVector; - for (unsigned int i = 0; i < perm.GetSize(); ++i) - { - permuteVector.push_back(permuteMappings.at(i)); - } + arm_compute::PermutationVector aclPerm; + auto rank = perm.GetSize(); - unsigned int start = 0; - while ((start < perm.GetSize()) && (start == permuteVector[start])) + // Reverse the order. i.e. {1, 0, 2, 3} --> {3, 2, 0, 1} + std::vector<unsigned int> reversedPerm; + reversedPerm.reserve(rank); + for (unsigned int i = rank; i > 0; --i) { - ++start; + reversedPerm.push_back(perm[i-1]); } - for (unsigned int i = start; i < perm.GetSize(); ++i) + // Translate from Arm NN axis to ACL axis. i.e. {3, 2, 0, 1} --> {0, 1, 3, 2} + for (unsigned int i = 0; i < rank; ++i) { - aclPerm.set(i - start, permuteVector[i] - start); + auto aclAxis = rank - 1 - reversedPerm[i]; + aclPerm.set(i, aclAxis); } return aclPerm; } |