aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/PermuteAsReshape.hpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
committertelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
commit4fcda0101ec3d110c1d6d7bee5c83416b645528a (patch)
treec9a70aeb2887006160c1b3d265c27efadb7bdbae /src/armnn/optimizations/PermuteAsReshape.hpp
downloadarmnn-4fcda0101ec3d110c1d6d7bee5c83416b645528a.tar.gz
Release 18.02
Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6
Diffstat (limited to 'src/armnn/optimizations/PermuteAsReshape.hpp')
-rw-r--r--src/armnn/optimizations/PermuteAsReshape.hpp70
1 files changed, 70 insertions, 0 deletions
diff --git a/src/armnn/optimizations/PermuteAsReshape.hpp b/src/armnn/optimizations/PermuteAsReshape.hpp
new file mode 100644
index 0000000000..a8e4c2df5e
--- /dev/null
+++ b/src/armnn/optimizations/PermuteAsReshape.hpp
@@ -0,0 +1,70 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#pragma once
+
+#include "Optimization.hpp"
+
+namespace armnn
+{
+namespace optimizations
+{
+
+class PermuteAsReshapeImpl
+{
+public:
+ /// Run for every PermuteLayer. Replaces it with a ReshapeLayer if they are equivalent.
+ void Run(Graph& graph, PermuteLayer& permute) const
+ {
+ if (IsReshape(permute))
+ {
+ const TensorInfo& outInfo = permute.GetOutputHandler().GetTensorInfo();
+
+ const std::string name = std::string("as_reshape-") + permute.GetName();
+ const ReshapeDescriptor descriptor{outInfo.GetShape()};
+ // Insert so layers don't need to be re-sorted
+ auto reshape = graph.InsertNewLayer<ReshapeLayer>(permute.GetInputSlot(0), descriptor, name.c_str());
+ reshape->GetOutputHandler().SetTensorInfo(outInfo);
+
+ // Bypass permute. It will be deleted since it's left unconnected.
+ permute.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot());
+ }
+ }
+
+protected:
+ PermuteAsReshapeImpl() = default;
+ ~PermuteAsReshapeImpl() = default;
+
+private:
+ static bool IsReshape(const PermuteLayer& layer)
+ {
+ const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape();
+ const PermutationVector& permutation = layer.GetPermutation();
+
+ const unsigned int numDimensions = permutation.GetSize();
+
+ unsigned int lastGtOne = 0;
+ while ((lastGtOne < numDimensions) && (outShape[(permutation[lastGtOne])] == 1U))
+ {
+ ++lastGtOne;
+ }
+
+ bool isReshape = true;
+ for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i)
+ {
+ if (outShape[permutation[i]] > 1U)
+ {
+ isReshape = permutation[lastGtOne] < permutation[i];
+ lastGtOne = i;
+ }
+ }
+
+ return isReshape;
+ }
+};
+
+using PermuteAsReshape = OptimizeForType<PermuteLayer, PermuteAsReshapeImpl>;
+
+} // namespace optimizations
+} // namespace armnn