blob: 9ea20907dffeb1942704213e5e17ff022d9825a6 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
|
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "Optimization.hpp"
namespace armnn
{
namespace optimizations
{
class DeleteBroadcastToImpl
{
public:
/// Run for every BroadcastToLayer. Remove it if it is before an ElementWiseLayer.
/// Since ElementWiseBinary uses a brodcastLoop, using a broadcastTo layer is
/// not necessary so it will be deleted.
void Run(Graph&, BroadcastToLayer& layer) const
{
if(layer.GetType() == LayerType::BroadcastTo)
{
Layer& next = layer.GetOutputSlot(0).GetConnection(0)->GetOwningLayer();
if (next.GetType() == LayerType::ElementwiseBinary)
{
Layer& connectedLayer = layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer();
layer.GetOutputSlot().MoveAllConnections(connectedLayer.GetOutputSlot());
}
}
}
protected:
DeleteBroadcastToImpl() = default;
~DeleteBroadcastToImpl() = default;
};
using BroadcastToOptimizationLayer = OptimizeForType<BroadcastToLayer, DeleteBroadcastToImpl>;
}
}
|