From 490b7becb8029ead26423b0d62e631a929e55d6c Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Tue, 3 Mar 2020 12:39:09 +0000 Subject: IVGCVSW-4375 Add support for Transpose to optimizations * Changed some existing Permutation specific optimizations to also support Transpose * Added MoveTransposeUp optimization * Added TransposeAsReshape optimization * Added tests for Transpose optimizations * Added missing layer tests for Transpose Signed-off-by: Mike Kelly Change-Id: I20d099b284861402ae94aaa5dbf34907327a485f --- src/armnn/optimizations/TransposeAsReshape.hpp | 81 ++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 src/armnn/optimizations/TransposeAsReshape.hpp (limited to 'src/armnn/optimizations/TransposeAsReshape.hpp') diff --git a/src/armnn/optimizations/TransposeAsReshape.hpp b/src/armnn/optimizations/TransposeAsReshape.hpp new file mode 100644 index 0000000000..4bb2f192f3 --- /dev/null +++ b/src/armnn/optimizations/TransposeAsReshape.hpp @@ -0,0 +1,81 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "Optimization.hpp" + +namespace armnn +{ +namespace optimizations +{ + +class TransposeAsReshapeImpl +{ +public: + /// Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent. + void Run(Graph& graph, TransposeLayer& transpose) const + { + if (IsReshape(transpose)) + { + const TensorInfo& outInfo = transpose.GetOutputHandler().GetTensorInfo(); + + const std::string name = std::string("as_reshape-") + transpose.GetName(); + const ReshapeDescriptor descriptor{outInfo.GetShape()}; + // Inserts NewLayer so layers don't need to be re-sorted. + auto reshape = graph.InsertNewLayer(transpose.GetInputSlot(0), descriptor, name.c_str()); + reshape->GetOutputHandler().SetTensorInfo(outInfo); + + // Bypass transpose. It will be deleted since it's left unconnected. + transpose.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot()); + } + } + +protected: + TransposeAsReshapeImpl() = default; + ~TransposeAsReshapeImpl() = default; + +private: + static bool IsReshape(const TransposeLayer& layer) + { + const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape(); + const PermutationVector& permutation = layer.GetPermutation(); + + const unsigned int numDimensions = permutation.GetSize(); + std::map permuteMappings; + for (unsigned int i = 0; i < permutation.GetSize(); ++i) + { + permuteMappings[permutation[i]] = i; + } + + std::vector permuteVector; + for (unsigned int i = 0; i < permutation.GetSize(); ++i) + { + permuteVector.push_back(permuteMappings.at(i)); + } + + unsigned int lastGtOne = 0; + while ((lastGtOne < numDimensions) && (outShape[(permuteVector[lastGtOne])] == 1U)) + { + ++lastGtOne; + } + + bool isReshape = true; + for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i) + { + if (outShape[permuteVector[i]] > 1U) + { + isReshape = permuteVector[lastGtOne] < permuteVector[i]; + lastGtOne = i; + } + } + + return isReshape; + } +}; + +using TransposeAsReshape = OptimizeForType; + +} // namespace optimizations +} // namespace armnn -- cgit v1.2.1