aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/test/TransformIteratorTest.cpp
blob: 2151153913686a3888a11cc0f65b5c75d60946dc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <armnn/utility/TransformIterator.hpp>

#include <doctest/doctest.h>

#include <iostream>
#include <string>
#include <vector>

using namespace armnn;

TEST_SUITE("TransformIteratorSuite")
{
namespace
{

// Integer transform used by the tests below: yields the square of its argument.
static int square(const int val)
{
    const int squared = val * val;
    return squared;
}

// String transform used by the tests below: returns a new string equal to
// `val` with the character "a" appended. The argument is taken by const
// reference because it is only read; a by-value `const std::string`
// parameter would copy the caller's string on every call with no way to
// move from it.
static std::string concat(const std::string& val)
{
    return val + "a";
}

// Exercises TransformIterator in two ways:
//  * range-for over small wrapper types whose begin()/end() return
//    TransformIterators, which drives the iterator's !=, ++ and * operators;
//  * MakeTransformIterator feeding std::vector's iterator-range constructor.
TEST_CASE("TransformIteratorTest")
{
    // Range adaptor that yields square(x) for every x in m_Vec.
    struct WrapperTestClass
    {
        using Iterator = TransformIterator<decltype(&square), std::vector<int>::const_iterator>;

        Iterator begin() const
        {
            return { m_Vec.begin(), &square };
        }

        Iterator end() const
        {
            return { m_Vec.end(), &square };
        }

        const std::vector<int> m_Vec{1, 2, 3, 4, 5};
    };

    // Range adaptor that yields concat(s) (s with "a" appended) for every s in m_Vec.
    struct WrapperStringClass
    {
        using Iterator = TransformIterator<decltype(&concat), std::vector<std::string>::const_iterator>;

        Iterator begin() const
        {
            return { m_Vec.begin(), &concat };
        }

        Iterator end() const
        {
            return { m_Vec.end(), &concat };
        }

        const std::vector<std::string> m_Vec{"a", "b", "c"};
    };

    WrapperStringClass wrapperStringClass;
    WrapperTestClass wrapperTestClass;

    // Compare the exact sequence the range yields. This pins the transformed
    // values, their order and their count; a per-element inequality check
    // would also pass for untransformed or truncated output.
    const std::vector<std::string> expectedStrings{"aa", "ba", "ca"};
    std::vector<std::string> visitedStrings;
    for (auto val : wrapperStringClass)
    {
        visitedStrings.push_back(val);
    }
    CHECK(visitedStrings == expectedStrings);

    // Element k (1-based) of {1, ..., 5} must come back as square(k).
    int i = 1;
    for (auto val : wrapperTestClass)
    {
        CHECK(val == square(i));
        i++;
    }
    // i ends one past the last visited position; guards against a range that
    // silently yields fewer (or no) elements, which the loop alone cannot catch.
    CHECK(i - 1 == static_cast<int>(wrapperTestClass.m_Vec.size()));

    // Iterating through the adaptor must leave the underlying vector untouched.
    i = 1;
    for (auto val : wrapperTestClass.m_Vec)
    {
        CHECK(val == i);
        i++;
    }

    std::vector<int> originalVec{1, 2, 3, 4, 5};

    auto transformBegin = MakeTransformIterator(originalVec.begin(), &square);
    auto transformEnd = MakeTransformIterator(originalVec.end(), &square);

    // The range constructor walks [transformBegin, transformEnd), so the result
    // must hold exactly one transformed entry per source element.
    std::vector<int> transformedVec(transformBegin, transformEnd);
    CHECK(transformedVec.size() == originalVec.size());

    i = 1;
    for (auto val : transformedVec)
    {
        CHECK(val == square(i));
        i++;
    }
}

}

}