diff options
Diffstat (limited to 'src/armnn/optimizations/TransposeAsReshape.hpp')
-rw-r--r-- | src/armnn/optimizations/TransposeAsReshape.hpp | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/src/armnn/optimizations/TransposeAsReshape.hpp b/src/armnn/optimizations/TransposeAsReshape.hpp new file mode 100644 index 0000000000..4bb2f192f3 --- /dev/null +++ b/src/armnn/optimizations/TransposeAsReshape.hpp @@ -0,0 +1,81 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "Optimization.hpp" + +namespace armnn +{ +namespace optimizations +{ + +class TransposeAsReshapeImpl +{ +public: + /// Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent. + void Run(Graph& graph, TransposeLayer& transpose) const + { + if (IsReshape(transpose)) + { + const TensorInfo& outInfo = transpose.GetOutputHandler().GetTensorInfo(); + + const std::string name = std::string("as_reshape-") + transpose.GetName(); + const ReshapeDescriptor descriptor{outInfo.GetShape()}; + // Inserts NewLayer so layers don't need to be re-sorted. + auto reshape = graph.InsertNewLayer<ReshapeLayer>(transpose.GetInputSlot(0), descriptor, name.c_str()); + reshape->GetOutputHandler().SetTensorInfo(outInfo); + + // Bypass transpose. It will be deleted since it's left unconnected. + transpose.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot()); + } + } + +protected: + TransposeAsReshapeImpl() = default; + ~TransposeAsReshapeImpl() = default; + +private: + static bool IsReshape(const TransposeLayer& layer) + { + const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape(); + const PermutationVector& permutation = layer.GetPermutation(); + + const unsigned int numDimensions = permutation.GetSize(); + std::map<unsigned int, unsigned int> permuteMappings; + for (unsigned int i = 0; i < permutation.GetSize(); ++i) + { + permuteMappings[permutation[i]] = i; + } + + std::vector<unsigned int> permuteVector; + for (unsigned int i = 0; i < permutation.GetSize(); ++i) + { + permuteVector.push_back(permuteMappings.at(i)); + } + + unsigned int lastGtOne = 0; + while ((lastGtOne < numDimensions) && (outShape[(permuteVector[lastGtOne])] == 1U)) + { + ++lastGtOne; + } + + bool isReshape = true; + for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i) + { + if (outShape[permuteVector[i]] > 1U) + { + isReshape = permuteVector[lastGtOne] < permuteVector[i]; + lastGtOne = i; + } + } + + return isReshape; + } +}; + +using TransposeAsReshape = OptimizeForType<TransposeLayer, TransposeAsReshapeImpl>; + +} // namespace optimizations +} // namespace armnn |