aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/GraphUtils.cpp
blob: bc6b562c9dd77ec3f40f45f8693d5ff042f1f296 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "GraphUtils.hpp"

#include <armnn/utility/PolymorphicDowncast.hpp>

/// Reports whether any layer in @p graph has a name equal to @p name.
///
/// @param graph  Graph whose layers are scanned in iteration order.
/// @param name   Exact name to look for.
/// @return true as soon as one layer's GetName() compares equal to @p name, false otherwise.
bool GraphHasNamedLayer(const armnn::Graph& graph, const std::string& name)
{
    bool found = false;
    for (auto&& candidate : graph)
    {
        // std::string == const char* comparison; identical to the reversed-operand form.
        if (name == candidate->GetName())
        {
            found = true;
            break;
        }
    }
    return found;
}

/// Finds the first layer (in graph iteration order) whose name equals @p name.
///
/// @param graph  Graph to search.
/// @param name   Exact layer name to match against GetNameStr().
/// @return Pointer to the first matching layer, or nullptr when no layer matches.
armnn::Layer* GetFirstLayerWithName(armnn::Graph& graph, const std::string& name)
{
    armnn::Layer* match = nullptr;
    for (auto&& candidate : graph)
    {
        if (candidate->GetNameStr() == name)
        {
            match = candidate;
            break;
        }
    }
    return match;
}

/// Returns true when @p layer has exactly @p num input slots.
/// @p layer must be non-null; it is dereferenced unconditionally.
bool CheckNumberOfInputSlot(armnn::Layer* layer, unsigned int num)
{
    const auto actualInputs = layer->GetNumInputSlots();
    return num == actualInputs;
}

/// Returns true when @p layer has exactly @p num output slots.
/// @p layer must be non-null; it is dereferenced unconditionally.
bool CheckNumberOfOutputSlot(armnn::Layer* layer, unsigned int num)
{
    const auto actualOutputs = layer->GetNumOutputSlots();
    return num == actualOutputs;
}

/// Checks that output slot @p srcSlot of @p srcLayer carries @p expectedTensorInfo and
/// feeds input slot @p destSlot of a layer named like @p destLayer.
///
/// @param srcLayer            Producing layer (must be non-null).
/// @param destLayer           Layer expected on the consuming side. Only its name is compared,
///                            so layers that share a name are indistinguishable here.
/// @param srcSlot             Index of the output slot on @p srcLayer.
/// @param destSlot            Expected input-slot index on the consuming layer.
/// @param expectedTensorInfo  Tensor info the output slot must report.
/// @return false if the tensor info differs or no matching connection exists; true otherwise.
bool IsConnected(armnn::Layer* srcLayer, armnn::Layer* destLayer,
                 unsigned int srcSlot, unsigned int destSlot,
                 const armnn::TensorInfo& expectedTensorInfo)
{
    const armnn::IOutputSlot& producer = srcLayer->GetOutputSlot(srcSlot);

    // The tensor info gate is checked before any connection is inspected.
    if (expectedTensorInfo != producer.GetTensorInfo())
    {
        return false;
    }

    const unsigned int connectionCount = producer.GetNumConnections();
    for (unsigned int idx = 0; idx != connectionCount; ++idx)
    {
        const auto* consumer =
            armnn::PolymorphicDowncast<const armnn::InputSlot*>(producer.GetConnection(idx));

        // Name match first, then slot index (same order as the short-circuit in the original).
        const bool nameMatches = consumer->GetOwningLayer().GetNameStr() == destLayer->GetNameStr();
        if (nameMatches && consumer->GetSlotIndex() == destSlot)
        {
            return true;
        }
    }
    return false;
}

/// Checks that @p first comes strictly before @p second in the graph's topological order.
///
/// The graph is printed via Graph::Print() and then topologically sorted before the check.
///
/// @param graph   Graph whose topological order is inspected.
/// @param first   Layer expected to appear earlier.
/// @param second  Layer expected to appear later.
/// @return true only if @p first is in the graph and @p second appears after it.
///         Returns false if @p first is absent, if @p second is absent or precedes @p first,
///         or if both arguments are the same layer (a layer cannot come before itself).
bool CheckOrder(const armnn::Graph& graph, const armnn::Layer* first, const armnn::Layer* second)
{
    graph.Print();

    const auto& order = graph.TopologicalSort();

    auto firstPos = std::find(order.begin(), order.end(), first);
    if (firstPos == order.end())
    {
        return false;
    }

    // Search strictly after 'first'. Starting at firstPos itself would wrongly
    // report first == second as "ordered".
    ++firstPos;
    auto secondPos = std::find(firstPos, order.end(), second);

    return (secondPos != order.end());
}