diff options
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/layers/BroadcastToLayer.cpp | 1 | ||||
-rw-r--r-- | src/armnn/optimizations/DeleteBroadcastTo.hpp | 3 |
2 files changed, 3 insertions, 1 deletions
diff --git a/src/armnn/layers/BroadcastToLayer.cpp b/src/armnn/layers/BroadcastToLayer.cpp index 252aa46de0..26a90eefed 100644 --- a/src/armnn/layers/BroadcastToLayer.cpp +++ b/src/armnn/layers/BroadcastToLayer.cpp @@ -6,7 +6,6 @@ #include "BroadcastToLayer.hpp" #include "LayerCloneBase.hpp" - #include <armnn/TypesUtils.hpp> #include <armnn/backends/WorkloadData.hpp> #include <armnn/backends/WorkloadFactory.hpp> diff --git a/src/armnn/optimizations/DeleteBroadcastTo.hpp b/src/armnn/optimizations/DeleteBroadcastTo.hpp index 9ea20907df..38396c1a9c 100644 --- a/src/armnn/optimizations/DeleteBroadcastTo.hpp +++ b/src/armnn/optimizations/DeleteBroadcastTo.hpp @@ -20,11 +20,14 @@ public: { if(layer.GetType() == LayerType::BroadcastTo) { + TensorInfo info = layer.GetOutputSlot(0).GetTensorInfo(); Layer& next = layer.GetOutputSlot(0).GetConnection(0)->GetOwningLayer(); if (next.GetType() == LayerType::ElementwiseBinary) { Layer& connectedLayer = layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer(); + auto tensorInfo = connectedLayer.GetOutputSlot().GetTensorInfo(); layer.GetOutputSlot().MoveAllConnections(connectedLayer.GetOutputSlot()); + connectedLayer.GetOutputSlot().GetOutputHandler().SetTensorInfo(tensorInfo); } } } |