aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/DeleteBroadcastTo.hpp
blob: 9ea20907dffeb1942704213e5e17ff022d9825a6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "Optimization.hpp"

namespace armnn
{
namespace optimizations
{
class DeleteBroadcastToImpl
{
public:
    /// Run for every BroadcastToLayer. Remove it if it is before an ElementWiseLayer.
    /// Since ElementWiseBinary uses a brodcastLoop, using a broadcastTo layer is
    /// not necessary so it will be deleted.
    void Run(Graph&, BroadcastToLayer& layer) const
    {
        if(layer.GetType() == LayerType::BroadcastTo)
        {
            Layer& next = layer.GetOutputSlot(0).GetConnection(0)->GetOwningLayer();
            if (next.GetType() == LayerType::ElementwiseBinary)
            {
                Layer& connectedLayer = layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer();
                layer.GetOutputSlot().MoveAllConnections(connectedLayer.GetOutputSlot());
            }
        }
    }
protected:
    DeleteBroadcastToImpl() = default;
    ~DeleteBroadcastToImpl() = default;
};
using BroadcastToOptimizationLayer = OptimizeForType<BroadcastToLayer, DeleteBroadcastToImpl>;
}
}