diff options
author | Mike Kelly <mike.kelly@arm.com> | 2019-12-11 18:32:43 +0000 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2019-12-11 18:53:31 +0000 |
commit | c1d07143c3d01df58dd5f0e4a10b38b7bd3565d4 (patch) | |
tree | 748100d2a7ca542d3c85a1d1400b6eace2eccd4e /src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp | |
parent | 356bfec771858ed435874b525fd88da505380103 (diff) | |
download | armnn-c1d07143c3d01df58dd5f0e4a10b38b7bd3565d4.tar.gz |
MLCE-143 Fixed driver crash during CTS tests
MLCE-144 Fix cts MAX_POOL_2D_V1_0 tests
* Only apply the Optimization when the base ReshapeLayer is connected to
the child ReshapeLayer and no other Layer.
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: Id1215e8b1c06d7bdb77905fec9649a8ec26436f0
Diffstat (limited to 'src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp')
-rw-r--r-- | src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp index 5047d5d678..e2d4a2dcc3 100644 --- a/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp +++ b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp @@ -14,7 +14,7 @@ namespace optimizations class OptimizeConsecutiveReshapesImpl { public: - /// Run for every connection between a base RashapeLayer and a child ReshapeLayer. + /// Run for every connection between a base ReshapeLayer and a child ReshapeLayer. /// Inserts an equivalent ReshapeLayer that bypasses both for that connection. void Run(Graph& graph, InputSlot& connection) const { @@ -29,12 +29,20 @@ public: const TensorInfo& inInfo = parentOut->GetTensorInfo(); const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo(); + // This Optimization is only appropriate when the base ReshapeLayer is connected to the child ReshapeLayer + // and no other Layer. + if (base.GetOutputSlot(0).GetNumConnections() > 1) + { + return; + } + if (inInfo.GetShape() != outInfo.GetShape()) { // Inserts equivalent reshape before base layer. const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName(); const ReshapeDescriptor descriptor{outInfo.GetShape()}; auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str()); + // Sets tensor info for new layer. newReshape.GetOutputHandler().SetTensorInfo(outInfo); // Reconnects base with original parent. |