aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/LayerTests.cpp
blob: 3e9184e3be266f7695e4c647e29c1cca4ea46a8a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <Network.hpp>
#include <doctest/doctest.h>
#include <armnn/utility/PolymorphicDowncast.hpp>

namespace
{

TEST_SUITE("Layer")
{

// Verifies how an InputSlot reports its TensorInfo:
//  - before the connected OutputSlot has a TensorInfo, the InputSlot reports none;
//  - afterwards it reports the OutputSlot's TensorInfo and is not flagged as overridden;
//  - after SetTensorInfo on the InputSlot itself, it reports the new info and is flagged as overridden,
//    while the connected OutputSlot still holds its original TensorInfo.
TEST_CASE("InputSlotGetTensorInfo")
{
    // Minimal graph: an elementwise Add layer whose first output feeds an Output layer.
    armnn::NetworkImpl network;
    armnn::IConnectableLayer* const addConnectable = network.AddElementwiseBinaryLayer(armnn::BinaryOperation::Add);
    armnn::IConnectableLayer* const outputConnectable = network.AddOutputLayer(0);

    // Downcast to armnn::Layer to work with its input/output slots directly.
    armnn::Layer* const addLayer = armnn::PolymorphicDowncast<armnn::Layer*>(addConnectable);
    armnn::Layer* const outputLayer = armnn::PolymorphicDowncast<armnn::Layer*>(outputConnectable);

    auto& producerSlot = addLayer->GetOutputSlot(0);
    auto& consumerSlot = outputLayer->GetInputSlot(0);

    const armnn::TensorInfo producedInfo({1, 2, 2, 1}, armnn::DataType::Float32);

    // Making the connection alone does not give the input slot a TensorInfo.
    producerSlot.Connect(consumerSlot);
    CHECK_FALSE(consumerSlot.IsTensorInfoSet());

    // Once the producing output slot carries a TensorInfo, the connected input slot reports
    // that same info, and it is not marked as an override.
    producerSlot.SetTensorInfo(producedInfo);
    const armnn::TensorInfo inheritedInfo = consumerSlot.GetTensorInfo();

    CHECK_EQ(producedInfo, inheritedInfo);
    CHECK(consumerSlot.IsTensorInfoSet());
    CHECK_FALSE(consumerSlot.IsTensorInfoOverridden());

    // Setting a TensorInfo directly on the input slot changes what the input slot reports...
    const armnn::TensorInfo overrideInfo({2, 2}, armnn::DataType::Float32);
    consumerSlot.SetTensorInfo(overrideInfo);
    const armnn::TensorInfo reportedInfo = consumerSlot.GetTensorInfo();

    CHECK_EQ(overrideInfo, reportedInfo);
    // ...while the TensorInfo held by the connected output slot stays as it was.
    CHECK_EQ(producedInfo, consumerSlot.GetConnection()->GetTensorInfo());

    CHECK(consumerSlot.IsTensorInfoOverridden());
}

}

}