aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/VerificationHelpers.cpp
blob: 243d22e4449d5599fb9dc161c6ceed9c30174f73 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "VerificationHelpers.hpp"
#include <boost/format.hpp>
#include <armnn/Exceptions.hpp>

#include <algorithm>
#include <cstdint>
#include <limits>

using namespace armnn;

namespace armnnUtils
{

/// Checks that a size/count value is one of an explicitly listed set of accepted values.
///
/// @param validInputCounts Accepted values.
/// @param actualValue      Value being checked.
/// @param validExpr        Source text of the accepted-values expression; used only in the error message.
/// @param actualExpr       Source text of the checked expression; used only in the error message.
/// @param location         Call-site location; its AsString() form is appended to the error message.
/// @throws ParseException  if @p actualValue does not appear in @p validInputCounts.
void CheckValidSize(std::initializer_list<size_t> validInputCounts,
                    size_t actualValue,
                    const char* validExpr,
                    const char* actualExpr,
                    const CheckLocation& location)
{
    const bool isValid =
        std::find(validInputCounts.begin(), validInputCounts.end(), actualValue) != validInputCounts.end();
    if (!isValid)
    {
        throw ParseException(
            boost::str(
                boost::format("%1% = %2% is not valid, not in {%3%}. %4%") %
                              actualExpr %
                              actualValue %
                              validExpr %
                              location.AsString()));
    }
}

/// Converts a signed 32-bit value to unsigned, rejecting negative inputs.
///
/// @param expr     Source text of the checked expression; used only in the error message.
/// @param value    Value to convert.
/// @param location Call-site location; its AsString() form is appended to the error message.
/// @return @p value as uint32_t.
/// @throws ParseException if @p value is negative.
uint32_t NonNegative(const char* expr,
                     int32_t value,
                     const CheckLocation& location)
{
    if (value < 0)
    {
        throw ParseException(
            boost::str(
                boost::format("'%1%' must be non-negative, received: %2% at %3%") %
                              expr %
                              value %
                              location.AsString() ));
    }
    // Every non-negative int32_t value is representable in uint32_t, so this cast is lossless.
    return static_cast<uint32_t>(value);
}

/// Narrows a 64-bit signed value to int32_t, rejecting values outside the int32_t range.
///
/// @param expr     Source text of the checked expression; used only in the error message.
/// @param value    Value to narrow.
/// @param location Call-site location; its AsString() form is appended to the error message.
/// @return @p value as int32_t.
/// @throws ParseException if @p value does not fit in int32_t.
int32_t VerifyInt32(const char* expr,
                     int64_t value,
                     const armnn::CheckLocation& location)
{
    // Bound against int32_t itself (the return type), not `int`, whose width is implementation-defined.
    if (value < std::numeric_limits<int32_t>::min() || value > std::numeric_limits<int32_t>::max())
    {
        throw ParseException(
            boost::str(
                boost::format("'%1%' must fit into an int32 (ArmNN doesn't support int64), received: %2% at %3%") %
                              expr %
                              value %
                              location.AsString() ));
    }
    return static_cast<int32_t>(value);
}

}// armnnUtils