aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/DeleteBroadcastTo.hpp
blob: 38396c1a9c27cc106b618530514becff9fba0742 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "Optimization.hpp"

namespace armnn
{
namespace optimizations
{
class DeleteBroadcastToImpl
{
public:
    /// Run for every BroadcastToLayer. Remove it if it is before an ElementWiseLayer.
    /// Since ElementWiseBinary uses a brodcastLoop, using a broadcastTo layer is
    /// not necessary so it will be deleted.
    void Run(Graph&, BroadcastToLayer& layer) const
    {
        if(layer.GetType() == LayerType::BroadcastTo)
        {
            TensorInfo info = layer.GetOutputSlot(0).GetTensorInfo();
            Layer& next = layer.GetOutputSlot(0).GetConnection(0)->GetOwningLayer();
            if (next.GetType() == LayerType::ElementwiseBinary)
            {
                Layer& connectedLayer = layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer();
                auto tensorInfo = connectedLayer.GetOutputSlot().GetTensorInfo();
                layer.GetOutputSlot().MoveAllConnections(connectedLayer.GetOutputSlot());
                connectedLayer.GetOutputSlot().GetOutputHandler().SetTensorInfo(tensorInfo);
            }
        }
    }
protected:
    DeleteBroadcastToImpl() = default;
    ~DeleteBroadcastToImpl() = default;
};
/// Optimization that pairs DeleteBroadcastToImpl with BroadcastToLayer through OptimizeForType,
/// so that DeleteBroadcastToImpl::Run is the step applied to BroadcastToLayer instances.
using BroadcastToOptimizationLayer = OptimizeForType<BroadcastToLayer, DeleteBroadcastToImpl>;
}
}