From 88d5f9f1615fa956464b8932b574d85c37cec937 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Fri, 29 Nov 2019 16:35:55 +0000 Subject: MLCE-143 Fixed driver crash during CTS tests * Only apply the Optimization when the base ReshapeLayer is connected to the child ReshapeLayer and no other Layer. Signed-off-by: Mike Kelly Change-Id: Iccd676d657f9e7c829813f1bec9c82db8745d069 --- src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp index f2dd7d23ec..53d4a3c4fd 100644 --- a/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp +++ b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp @@ -14,7 +14,7 @@ namespace optimizations class OptimizeConsecutiveReshapesImpl { public: - /// Run for every connection between a base RashapeLayer and a child ReshapeLayer. + /// Run for every connection between a base ReshapeLayer and a child ReshapeLayer. /// Inserts an equivalent ReshapeLayer that bypasses both for that connection. void Run(Graph& graph, InputSlot& connection) const { @@ -29,12 +29,20 @@ public: const TensorInfo& inInfo = parentOut->GetTensorInfo(); const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo(); + // This Optimization is only appropriate when the base ReshapeLayer is connected to the child ReshapeLayer + // and no other Layer. + if (base.GetOutputSlot(0).GetNumConnections() > 1) + { + return; + } + if (inInfo.GetShape() != outInfo.GetShape()) { // Inserts equivalent reshape before base layer. const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName(); const ReshapeDescriptor descriptor{outInfo.GetShape()}; auto& newReshape = *graph.InsertNewLayer(base.GetInputSlot(0), descriptor, name.c_str()); + // Sets tensor info for new layer. newReshape.GetOutputHandler().SetTensorInfo(outInfo); // Parent is now the new layer. -- cgit v1.2.1