aboutsummaryrefslogtreecommitdiff
path: root/src/backends/aclCommon/ArmComputeTensorUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/aclCommon/ArmComputeTensorUtils.cpp')
-rw-r--r--src/backends/aclCommon/ArmComputeTensorUtils.cpp51
1 files changed, 34 insertions, 17 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index 517b11ced8..38c7f70da5 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -1,11 +1,12 @@
//
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <aclCommon/ArmComputeUtils.hpp>
#include "armnn/Exceptions.hpp"
+#include "ArmComputeUtils.hpp"
#include <armnn/Descriptors.hpp>
#include <fmt/format.h>
@@ -293,28 +294,44 @@ arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::Per
arm_compute::PermutationVector BuildArmComputeTransposeVector(const armnn::PermutationVector& perm)
{
- arm_compute::PermutationVector aclPerm;
- std::map<unsigned int, unsigned int> permuteMappings;
- for (unsigned int i = 0; i < perm.GetSize(); ++i)
- {
- permuteMappings[perm[i]] = i;
- }
+ // As Arm NN indices run left to right and ACL indices run right to left,
+ // the permutation vector has to be reversed and then translated into ACL axes.
+ // i.e. {1, 0, 2, 3} --> {3, 2, 0, 1} --> {0, 1, 3, 2}
+
+ // Below is an example of how the ArmNN and ACL index formats work:
+ // ArmNN Format:
+ // Input Shape {1, 10, 20, 30}
+ // Permutation Vector {1, 0, 2, 3}
+ // Output Shape {10, 1, 20, 30}
+ // dim "1" of input goes into index 0 of the output ([ 10, X, X, X ])
+ // dim "0" of input goes into index 1 of the output ([ 10, 1, X, X ])
+ // dim "2" of input goes into index 2 of the output ([ 10, 1, 20, X ])
+ // dim "3" of input goes into index 3 of the output ([ 10, 1, 20, 30 ])
+ // ACL Format:
+ // Input Shape {30, 20, 10, 1}
+ // Permutation Vector {0, 1, 3, 2}
+ // Output Shape {30, 20, 1, 10}
+ // dim "0" of input goes into index 0 of the output ([ 30, X, X, X ])
+ // dim "1" of input goes into index 1 of the output ([ 30, 20, X, X ])
+ // dim "3" of input goes into index 2 of the output ([ 30, 20, 1, X ])
+ // dim "2" of input goes into index 3 of the output ([ 30, 20, 1, 10 ])
- std::vector<unsigned int> permuteVector;
- for (unsigned int i = 0; i < perm.GetSize(); ++i)
- {
- permuteVector.push_back(permuteMappings.at(i));
- }
+ arm_compute::PermutationVector aclPerm;
+ auto rank = perm.GetSize();
- unsigned int start = 0;
- while ((start < perm.GetSize()) && (start == permuteVector[start]))
+ // Reverse the order. i.e. {1, 0, 2, 3} --> {3, 2, 0, 1}
+ std::vector<unsigned int> reversedPerm;
+ reversedPerm.reserve(rank);
+ for (unsigned int i = rank; i > 0; --i)
{
- ++start;
+ reversedPerm.push_back(perm[i-1]);
}
- for (unsigned int i = start; i < perm.GetSize(); ++i)
+ // Translate from Arm NN axis to ACL axis. i.e. {3, 2, 0, 1} --> {0, 1, 3, 2}
+ for (unsigned int i = 0; i < rank; ++i)
{
- aclPerm.set(i - start, permuteVector[i] - start);
+ auto aclAxis = rank - 1 - reversedPerm[i];
+ aclPerm.set(i, aclAxis);
}
return aclPerm;
}