aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/layers/BroadcastToLayer.cpp1
-rw-r--r--src/armnn/optimizations/DeleteBroadcastTo.hpp3
2 files changed, 3 insertions, 1 deletions
diff --git a/src/armnn/layers/BroadcastToLayer.cpp b/src/armnn/layers/BroadcastToLayer.cpp
index 252aa46de0..26a90eefed 100644
--- a/src/armnn/layers/BroadcastToLayer.cpp
+++ b/src/armnn/layers/BroadcastToLayer.cpp
@@ -6,7 +6,6 @@
#include "BroadcastToLayer.hpp"
#include "LayerCloneBase.hpp"
-
#include <armnn/TypesUtils.hpp>
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/backends/WorkloadFactory.hpp>
diff --git a/src/armnn/optimizations/DeleteBroadcastTo.hpp b/src/armnn/optimizations/DeleteBroadcastTo.hpp
index 9ea20907df..38396c1a9c 100644
--- a/src/armnn/optimizations/DeleteBroadcastTo.hpp
+++ b/src/armnn/optimizations/DeleteBroadcastTo.hpp
@@ -20,11 +20,14 @@ public:
{
if(layer.GetType() == LayerType::BroadcastTo)
{
+ TensorInfo info = layer.GetOutputSlot(0).GetTensorInfo();
Layer& next = layer.GetOutputSlot(0).GetConnection(0)->GetOwningLayer();
if (next.GetType() == LayerType::ElementwiseBinary)
{
Layer& connectedLayer = layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer();
+ auto tensorInfo = connectedLayer.GetOutputSlot().GetTensorInfo();
layer.GetOutputSlot().MoveAllConnections(connectedLayer.GetOutputSlot());
+ connectedLayer.GetOutputSlot().GetOutputHandler().SetTensorInfo(tensorInfo);
}
}
}