diff options
Diffstat (limited to 'src/armnn/optimizations/PermuteAsReshape.hpp')
-rw-r--r-- | src/armnn/optimizations/PermuteAsReshape.hpp | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/src/armnn/optimizations/PermuteAsReshape.hpp b/src/armnn/optimizations/PermuteAsReshape.hpp new file mode 100644 index 0000000000..a8e4c2df5e --- /dev/null +++ b/src/armnn/optimizations/PermuteAsReshape.hpp @@ -0,0 +1,70 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "Optimization.hpp" + +namespace armnn +{ +namespace optimizations +{ + +class PermuteAsReshapeImpl +{ +public: + /// Run for every PermuteLayer. Replaces it with a ReshapeLayer if they are equivalent. + void Run(Graph& graph, PermuteLayer& permute) const + { + if (IsReshape(permute)) + { + const TensorInfo& outInfo = permute.GetOutputHandler().GetTensorInfo(); + + const std::string name = std::string("as_reshape-") + permute.GetName(); + const ReshapeDescriptor descriptor{outInfo.GetShape()}; + // Insert so layers don't need to be re-sorted + auto reshape = graph.InsertNewLayer<ReshapeLayer>(permute.GetInputSlot(0), descriptor, name.c_str()); + reshape->GetOutputHandler().SetTensorInfo(outInfo); + + // Bypass permute. It will be deleted since it's left unconnected. + permute.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot()); + } + } + +protected: + PermuteAsReshapeImpl() = default; + ~PermuteAsReshapeImpl() = default; + +private: + static bool IsReshape(const PermuteLayer& layer) + { + const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape(); + const PermutationVector& permutation = layer.GetPermutation(); + + const unsigned int numDimensions = permutation.GetSize(); + + unsigned int lastGtOne = 0; + while ((lastGtOne < numDimensions) && (outShape[(permutation[lastGtOne])] == 1U)) + { + ++lastGtOne; + } + + bool isReshape = true; + for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i) + { + if (outShape[permutation[i]] > 1U) + { + isReshape = permutation[lastGtOne] < permutation[i]; + lastGtOne = i; + } + } + + return isReshape; + } +}; + +using PermuteAsReshape = OptimizeForType<PermuteLayer, PermuteAsReshapeImpl>; + +} // namespace optimizations +} // namespace armnn |