aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/utility/TransformIterator.hpp
blob: 66fee8715dbc812c7e29c15e88e984db185e3542 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <iterator>
#include <type_traits>
#include <utility>

namespace armnn
{

/// Iterator adaptor that applies a function to the element referenced by a wrapped iterator on dereference.
///
/// Movement, distance and comparison operations are forwarded to the wrapped iterator, so which operations
/// compile depends on the capabilities of Iterator (e.g. operator[], operator+= and iterator difference
/// require a random-access Iterator; member bodies are only instantiated when used).
///
/// @tparam Function  Callable applied to each element; it is invoked as a const object (see operator*).
/// @tparam Iterator  The wrapped iterator type.
/// @tparam Reference Type returned by dereferencing; defaults to the result of calling a const Function
///                   with Iterator's reference type.
/// @note value_type and pointer default to those of Iterator, i.e. they describe the underlying elements,
///       not the transformed results.
template<typename Function,
        typename Iterator,
        typename Category = typename std::iterator_traits<Iterator>::iterator_category,
        typename T = typename std::iterator_traits<Iterator>::value_type,
        typename Distance = typename std::iterator_traits<Iterator>::difference_type,
        typename Pointer = typename std::iterator_traits<Iterator>::pointer,
        typename Reference = decltype(std::declval<const Function&>()(
            std::declval<typename std::iterator_traits<Iterator>::reference>()))
>
class TransformIterator
{

public:
    // Member types previously inherited from std::iterator (deprecated since C++17).
    using iterator_category = Category;
    using value_type        = T;
    using difference_type   = Distance;
    using pointer           = Pointer;
    using reference         = Reference;

    TransformIterator() = default;
    TransformIterator(TransformIterator const& transformIterator) = default;
    TransformIterator(TransformIterator&& transformIterator) = default;

    /// Wraps @p it so that dereferencing yields the result of @p fn applied to the underlying element.
    /// Taking both arguments by value accepts lvalue, rvalue and const iterators, and moves them into place.
    TransformIterator(Iterator it, Function fn) : m_it(std::move(it)), m_fn(std::move(fn)) {}

    ~TransformIterator() = default;

    /// Copies the wrapped iterator into *this. The function object is copied too when Function is
    /// copy-assignable; otherwise (e.g. lambda closures before C++20) the existing function object is kept.
    TransformIterator& operator=(TransformIterator const& transformIterator)
    {
        m_it = transformIterator.m_it;
        AssignFunction(transformIterator.m_fn, std::is_copy_assignable<Function>{});
        return *this;
    }

    /// Moves the wrapped iterator into *this. The function object is moved too when Function is
    /// move-assignable; otherwise the existing function object is kept.
    TransformIterator& operator=(TransformIterator&& transformIterator)
    {
        m_it = std::move(transformIterator.m_it);
        AssignFunction(std::move(transformIterator.m_fn), std::is_move_assignable<Function>{});
        return *this;
    }

    /// Pre-increment / pre-decrement: move the wrapped iterator and return *this.
    TransformIterator& operator++() {++m_it; return *this;}
    TransformIterator& operator--() {--m_it; return *this;}

    /// Post-increment: advances by one and returns a copy of the previous position.
    TransformIterator operator++(int)
    {
        TransformIterator previous(*this);
        ++m_it;
        return previous;
    }

    /// Post-decrement: steps back by one and returns a copy of the previous position.
    TransformIterator operator--(int)
    {
        TransformIterator previous(*this);
        --m_it;
        return previous;
    }

    /// Compound advance (random-access Iterator only).
    TransformIterator& operator+=(const Distance n) {m_it += n; return *this;}
    TransformIterator& operator-=(const Distance n) {m_it -= n; return *this;}

    /// Returns the transformed element n positions from the current one (random-access Iterator only).
    Reference operator[](const Distance n) const {return m_fn(m_it[n]);}

    /// Distance between two iterators, as reported by the wrapped iterators.
    Distance operator-(const TransformIterator& other) const {return m_it - other.m_it;}

    /// Returns a new iterator offset by n, sharing a copy of this iterator's function.
    TransformIterator operator-(const Distance n) const {return {m_it - n, m_fn};}
    TransformIterator operator+(const Distance n) const {return {m_it + n, m_fn};}

    // Ordering and equality compare only the wrapped iterators; the function objects are ignored.
    bool operator>(const TransformIterator& rhs) const {return m_it > rhs.m_it;}
    bool operator<(const TransformIterator& rhs) const {return m_it < rhs.m_it;}
    bool operator>=(const TransformIterator& rhs) const {return m_it >= rhs.m_it;}
    bool operator<=(const TransformIterator& rhs) const {return m_it <= rhs.m_it;}

    bool operator==(const TransformIterator& other) const {return (m_it == other.m_it);}
    bool operator!=(const TransformIterator& other) const {return !(m_it == other.m_it);}

    /// Applies the function to the current underlying element.
    Reference operator*() const {return m_fn(*m_it);}

private:
    // Tag-dispatched helpers so assignment works whether or not Function itself is assignable.
    template<typename F>
    void AssignFunction(F&& fn, std::true_type) {m_fn = std::forward<F>(fn);}

    template<typename F>
    void AssignFunction(F&&, std::false_type) {}

    Iterator m_it;
    Function m_fn;
};

/// Factory that deduces the template arguments of TransformIterator from its inputs.
/// @param i  The iterator to wrap (taken by value).
/// @param f  The function applied to each element on dereference (taken by value).
/// @return   A TransformIterator<Function, Iterator> constructed from @p i and @p f.
template<typename Function, typename Iterator>
constexpr TransformIterator<Function, Iterator> MakeTransformIterator(Iterator i, Function f)
{
    // The braced return list-initializes the result through the (iterator, function) constructor.
    return { i, f };
}

}