aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2019-11-29 16:35:55 +0000
committermike.kelly <mike.kelly@arm.com>2019-11-29 17:28:21 +0000
commit88d5f9f1615fa956464b8932b574d85c37cec937 (patch)
treed2d9ed3079799b1e6418936430f1e4de21c56a29
parentdf31cfe29f9dccc4c2055a1d2a97de644b07d522 (diff)
downloadarmnn-88d5f9f1615fa956464b8932b574d85c37cec937.tar.gz
MLCE-143 Fixed driver crash during CTS tests
* Only apply the Optimization when the base ReshapeLayer is connected to the child ReshapeLayer and no other Layer. Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: Iccd676d657f9e7c829813f1bec9c82db8745d069
-rw-r--r--src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp10
1 files changed, 9 insertions, 1 deletions
diff --git a/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp
index f2dd7d23ec..53d4a3c4fd 100644
--- a/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp
+++ b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp
@@ -14,7 +14,7 @@ namespace optimizations
class OptimizeConsecutiveReshapesImpl
{
public:
- /// Run for every connection between a base RashapeLayer and a child ReshapeLayer.
+ /// Run for every connection between a base ReshapeLayer and a child ReshapeLayer.
/// Inserts an equivalent ReshapeLayer that bypasses both for that connection.
void Run(Graph& graph, InputSlot& connection) const
{
@@ -29,12 +29,20 @@ public:
const TensorInfo& inInfo = parentOut->GetTensorInfo();
const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();
+ // This Optimization is only appropriate when the base ReshapeLayer is connected to the child ReshapeLayer
+ // and no other Layer.
+ if (base.GetOutputSlot(0).GetNumConnections() > 1)
+ {
+ return;
+ }
+
if (inInfo.GetShape() != outInfo.GetShape())
{
// Inserts equivalent reshape before base layer.
const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
const ReshapeDescriptor descriptor{outInfo.GetShape()};
auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());
+
// Sets tensor info for new layer.
newReshape.GetOutputHandler().SetTensorInfo(outInfo);
// Parent is now the new layer.