path: root/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
//
// Copyright © 2020-2021,2023-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "Optimization.hpp"

#include <armnn/backends/TensorHandle.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include <algorithm>
#include <set>
#include <string>
#include <vector>

namespace armnn
{
namespace optimizations
{

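// Layer types with two inputs whose shapes may need to be broadcast against each other.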
static const std::set<armnn::LayerType> broadcastOps{ LayerType::Addition,       LayerType::Division,
                                                      LayerType::Maximum,        LayerType::Minimum,
                                                      LayerType::Multiplication, LayerType::Prelu,
                                                      LayerType::Subtraction,    LayerType::ElementwiseBinary,
                                                      LayerType::Comparison,     LayerType::LogicalBinary};

class AddBroadcastReshapeLayerImpl
{
public:
    /// Run for every layer. If the layer is a broadcastable operation and its two inputs have
    /// different ranks, reshape the lower-rank input so that both inputs have the same number of
    /// dimensions, either by inserting a Reshape layer or by reshaping a single-use Constant parent in place.
    void Run(Graph& graph, Layer& layer) const
    {
        if (broadcastOps.find(layer.GetType()) != broadcastOps.end())
        {
            // Tensor info must be set on both inputs before their shapes can be compared.
            if (!layer.GetInputSlot(0).GetConnectedOutputSlot()->IsTensorInfoSet() ||
                !layer.GetInputSlot(1).GetConnectedOutputSlot()->IsTensorInfoSet())
            {
                return;
            }

            const TensorInfo& inputInfo0 = layer.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
            const TensorInfo& inputInfo1 = layer.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();

            if (inputInfo0.GetNumDimensions() == inputInfo1.GetNumDimensions())
            {
                return;
            }

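            // By default reshape input 1 to match input 0's rank; if input 0 has the lower rank,
            // reshape input 0 instead.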
            unsigned int reshapeSlot = 1;
            TensorInfo reshapeInfo   = inputInfo1;
            TensorInfo inputInfo     = inputInfo0;

            if (inputInfo0.GetNumDimensions() < inputInfo1.GetNumDimensions())
            {
                reshapeSlot = 0;
                reshapeInfo = inputInfo0;
                inputInfo   = inputInfo1;
            }

            const unsigned int numDimensions = inputInfo.GetNumDimensions();

            std::vector<unsigned int> reshapedDim;
            for (unsigned int i = 0; i < reshapeInfo.GetNumDimensions(); ++i)
            {
                reshapedDim.push_back(reshapeInfo.GetShape()[i]);
            }

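            // Right-align the lower-rank shape into a shape of rank numDimensions, padding the leading
            // dimensions with 1 (e.g. a {5} tensor broadcast against a 4-D input becomes {1, 1, 1, 5}).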
            std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
            std::copy_backward(reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());

            reshapeInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });

            // If the parent layer is a Constant layer and it is only used once, we can short-circuit
            // by changing its tensor info in place rather than inserting a Reshape layer.
            Layer& parentLayer = layer.GetInputSlot(reshapeSlot).GetConnectedOutputSlot()->GetOwningLayer();
            if ((parentLayer.GetType() == armnn::LayerType::Constant) &&
                (parentLayer.GetOutputSlot(0).GetNumConnections() == 1))
            {
                ConstantLayer& constantLayer = *PolymorphicDowncast<ConstantLayer*>(&parentLayer);

                constantLayer.m_LayerOutput = std::make_unique<ScopedTensorHandle>(
                    ConstTensor(reshapeInfo, constantLayer.m_LayerOutput->GetConstTensor<void>()));
                constantLayer.GetOutputSlot().SetTensorInfo(reshapeInfo);
            }
            else
            {
                const std::string layerName = "Reshape_for:" + layer.GetNameStr() + "-" + std::to_string(reshapeSlot);
                const ReshapeDescriptor descriptor{ reshapeInfo.GetShape() };
                ReshapeLayer* reshapeLayer =
                    graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot), descriptor, layerName.c_str());
                reshapeLayer->GetOutputSlot().SetTensorInfo(reshapeInfo);
            }
        }
    }

protected:
    AddBroadcastReshapeLayerImpl()  = default;
    ~AddBroadcastReshapeLayerImpl() = default;
};

using AddBroadcastReshapeLayer = OptimizeForType<Layer, AddBroadcastReshapeLayerImpl>;
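
// Usage sketch (illustrative, not part of this header): the pass is typically applied to a Graph
// through the Optimizer, assuming the standard MakeOptimizations helper, e.g.
//
//     armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(
//         armnn::optimizations::AddBroadcastReshapeLayer()));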

}    // namespace optimizations
}    // namespace armnn