aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/SplitterLayer.cpp
blob: 2d469b0bc926406d342e5417147c38c9156217b6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "SplitterLayer.hpp"

#include "LayerCloneBase.hpp"

#include <armnn/TypesUtils.hpp>
#include <backendsCommon/WorkloadData.hpp>
#include <backendsCommon/WorkloadFactory.hpp>

namespace armnn
{

/// Builds a splitter layer from a views descriptor.
/// The base class is given a count of 1 and a count equal to the number of views
/// (the rest of this file reads input slot 0 and one output slot per view).
SplitterLayer::SplitterLayer(const ViewsDescriptor& param, const char* name)
    : LayerWithParameters(1,
                          param.GetNumViews(),
                          LayerType::Splitter,
                          param,
                          name)
{
}

std::unique_ptr<IWorkload> SplitterLayer::CreateWorkload(const IWorkloadFactory& factory) const
{
    SplitterQueueDescriptor descriptor;

    // Copies the window origins to the descriptor.
    for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
    {
        descriptor.m_ViewOrigins.emplace_back(
            std::vector<unsigned int>(m_Param.GetViewOrigin(i), m_Param.GetViewOrigin(i) + m_Param.GetNumDimensions()));
    }

    return factory.CreateSplitter(descriptor, PrepInfoAndDesc(descriptor));
}

template<typename FactoryType>
void SplitterLayer::CreateTensors(const FactoryType& factory)
{
    //If sub tensors are supported than all the "splitter" need to do is to
    //set the outputs to be appropriate sub tensors of the input.
    bool useSubTensors = factory.SupportsSubTensors();

    if (useSubTensors)
    {
        const OutputSlot* slot = GetInputSlots()[0].GetConnectedOutputSlot();
        const OutputHandler& outputHandler = GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler();

        const TensorInfo& parentInfo = outputHandler.GetTensorInfo();

        ITensorHandle* inputData = outputHandler.GetData();

        std::vector<std::unique_ptr<ITensorHandle>> subTensors;

        //Creates the outputs as subtensors of the input.
        for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
        {
            const TensorInfo& info = m_OutputHandlers[i].GetTensorInfo();

            OutputSlot& outSlot = GetOutputSlot(i);
            ITensorHandleFactory::FactoryId factoryId = outSlot.GetTensorHandleFactoryId();
            auto CreateSubTensor = [&]()
            {
                // Make sure quantization parameters are in the same space
                if (parentInfo.IsTypeSpaceMatch(info) &&
                    factoryId == slot->GetTensorHandleFactoryId())
                {
                    return factory.CreateSubTensorHandle(*inputData,
                                                         info.GetShape(),
                                                         this->m_Param.GetViewOrigin(i));
                }
                return std::unique_ptr<ITensorHandle>();
            };

            auto subTensor = CreateSubTensor();
            if (!subTensor)
            {
                useSubTensors = false;
                break; //Failed to create a valid sub-tensor, so stop trying with the rest of the views.
            }
            subTensors.push_back(std::move(subTensor));
        }

        if (useSubTensors)
        {
            unsigned int i = 0;
            for (auto& subTensor : subTensors)
            {
                m_OutputHandlers[i].SetData(std::move(subTensor));
                ++i;
            }

        }
    }

    if (!useSubTensors)
    {
        for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
        {
            m_OutputHandlers[i].CreateTensorHandles(factory);
        }
    }
}

void SplitterLayer::CreateTensorHandles(const TensorHandleFactoryRegistry& registry,
                                        const IWorkloadFactory& workloadFactory,
                                        const bool IsMemoryManaged)
{
    IgnoreUnused(IsMemoryManaged);
    OutputSlot& slot = GetOutputSlot(0);
    ITensorHandleFactory::FactoryId factoryId = slot.GetTensorHandleFactoryId();

    if (factoryId == ITensorHandleFactory::LegacyFactoryId)
    {
        CreateTensors(workloadFactory);
    }
    else
    {
        ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId);
        ARMNN_ASSERT(handleFactory);
        CreateTensors(*handleFactory);
    }
}

/// Clones this layer into the given graph by delegating to CloneBase with this layer's
/// descriptor and name.
SplitterLayer* SplitterLayer::Clone(Graph& graph) const
{
    SplitterLayer* const duplicate = CloneBase<SplitterLayer>(graph, m_Param, GetName());
    return duplicate;
}

std::vector<TensorShape> SplitterLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
    IgnoreUnused(inputShapes);
    ARMNN_ASSERT(inputShapes.size() ==  m_Param.GetNumViews());
    std::vector<TensorShape> outShapes;
    //Output shapes must match View shapes.
    for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
    {
        const uint32_t* sizes = m_Param.GetViewSizes(viewIdx);
        outShapes.push_back(TensorShape(m_Param.GetNumDimensions(), sizes));
    }
    return outShapes;
}

void SplitterLayer::ValidateTensorShapesFromInputs()
{
    std::for_each(BeginOutputSlots(), EndOutputSlots(), [&](OutputSlot& outputSlot)
    {
        VerifyShapeInferenceType(outputSlot.GetTensorInfo().GetShape(), m_ShapeInferenceMethod);
    });

    std::vector<TensorShape> views;
    for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
    {
        const uint32_t* sizes = m_Param.GetViewSizes(viewIdx);
        views.push_back(TensorShape(m_Param.GetNumDimensions(), sizes));
    }

    auto inferredShapes = InferOutputShapes(views);

    ARMNN_ASSERT(inferredShapes.size() == m_Param.GetNumViews());

    for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
    {
        ValidateAndCopyShape(GetOutputSlot(viewIdx).GetTensorInfo().GetShape(),
                             inferredShapes[viewIdx],
                             m_ShapeInferenceMethod,
                             "SplitterLayer",
                             viewIdx);
    }
}

/// Visitor entry point: hands this layer, its parameters and its name to
/// ILayerVisitor::VisitSplitterLayer.
void SplitterLayer::Accept(ILayerVisitor& visitor) const
{
    const auto& viewsDescriptor = GetParameters();
    const auto  layerName       = GetName();
    visitor.VisitSplitterLayer(this, viewsDescriptor, layerName);
}

} // namespace armnn