aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/AbsEndToEndTestImpl.hpp
blob: dd851e3960321cfd841da34709f64fa1f78370ef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "CommonTestUtils.hpp"

#include <QuantizeHelper.hpp>
#include <ResolveType.hpp>

#include <armnn/ArmNN.hpp>

namespace
{

// Assembles the three-layer graph exercised by the Abs end-to-end test:
//
//     "input" (id 0)  ->  "abs"  ->  "output" (id 0)
//
// tensorInfo is forwarded unchanged to both Connect() calls, so the same
// tensor description is used on either side of the Abs layer.
// Returns the owning handle to the newly created network.
armnn::INetworkPtr CreateAbsNetwork(const armnn::TensorInfo& tensorInfo)
{
    armnn::INetworkPtr net(armnn::INetwork::Create());

    // Layers are added in graph order: source, operation under test, sink.
    auto* const source  = net->AddInputLayer(0, "input");
    auto* const absNode = net->AddAbsLayer("abs");
    auto* const sink    = net->AddOutputLayer(0, "output");

    // Wire the layers together, passing slot indices 0 and 0 for each link.
    Connect(source, absNode, tensorInfo, 0, 0);
    Connect(absNode, sink, tensorInfo, 0, 0);

    return net;
}

} // anonymous namespace

// End-to-end check of the Abs layer on the given backends.
//
// Template parameters:
//   ArmnnType - armnn data type used for the network's tensors.
//   T         - element type for the test data; defaults to ResolveType<ArmnnType>.
//
// A 1x1x2x3 tensor with mixed-sign values is fed through the network built by
// CreateAbsNetwork, and the result is compared (by EndToEndLayerTestImpl)
// against the same values with the signs removed.
template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
void AbsEndToEnd(const std::vector<armnn::BackendId>& backends)
{
    using namespace armnn;

    // Quantized element types get scale 0.25 / offset 50;
    // every other type uses scale 1.0 / offset 0.
    const bool    isQuantized = IsQuantizedType<T>();
    const float   scale       = isQuantized ? 0.25f : 1.0f;
    const int32_t offset      = isQuantized ? 50    : 0;

    TensorInfo info({ 1, 1, 2, 3 }, ArmnnType, scale, offset);
    INetworkPtr absNetwork = CreateAbsNetwork(info);

    // Reference values, expressed in float before conversion to T.
    const std::vector<float> rawInput =
    {
       -1.f,  2.f, -3.f,
        4.f, -5.f,  6.f
    };
    const std::vector<float> rawExpected =
    {
        1.f, 2.f, 3.f,
        4.f, 5.f, 6.f
    };

    // Convert both data sets to T using the same scale and offset as the tensor info.
    const std::vector<T> input    = armnnUtils::QuantizedVector<T>(rawInput, scale, offset);
    const std::vector<T> expected = armnnUtils::QuantizedVector<T>(rawExpected, scale, offset);

    // Input and expected output are both keyed by id 0.
    EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(absNetwork),
                                                { { 0, input } },
                                                { { 0, expected } },
                                                backends);
}