aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/errors.py
blob: 2c93fbc6441246fbf048dc15d99396f16d697c7e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Defines custom exceptions.
import sys

from .operation import Operation
from .tensor import Tensor


class VelaError(Exception):
    """Root of the vela exception hierarchy.

    The payload handed to the constructor is kept on the ``data`` attribute,
    and ``str()`` of the exception yields the ``repr()`` of that payload.
    """

    def __init__(self, data):
        # Keep the payload around so __str__ (and callers) can get at it.
        self.data = data

    def __str__(self):
        # The "!r" conversion renders the payload through repr().
        return "{!r}".format(self.data)


class InputFileError(VelaError):
    """Signals that the input file could not be read successfully.

    ``data`` holds a message naming the offending file followed by the
    supplied explanation.
    """

    def __init__(self, file_name, msg):
        # The base initializer stores the finished message on self.data.
        template = "Error reading input file {}: {}"
        super().__init__(template.format(file_name, msg))


class UnsupportedFeatureError(VelaError):
    """Signals that the input file relies on a feature that is not supported.

    ``data`` holds a fixed explanatory prefix followed by the supplied
    description of the feature.
    """

    def __init__(self, data):
        # The base initializer stores the finished message on self.data.
        template = "The input file uses a feature that is currently not supported: {}"
        super().__init__(template.format(data))


class OptionError(VelaError):
    """Signals that a command line option was given an incorrect argument.

    ``data`` holds a message made of the option, the value it was given and
    the supplied explanation.
    """

    def __init__(self, option, option_value, msg):
        # The base initializer stores the finished message on self.data.
        template = "Incorrect argument to CLI option: {} {}: {}"
        super().__init__(template.format(option, option_value, msg))


def OperatorError(op, msg):
    """Report an invalid operator and terminate the program.

    Prints, to standard output, an "Error: ..." report that identifies the
    operator (by op_index when it has one, otherwise by name), repeats msg and
    lists every entry of op.inputs and op.outputs, then calls sys.exit(1).
    Entries that are not Tensor instances are listed as "Not a Tensor".
    """

    assert isinstance(op, Operation)

    def describe(tensors):
        # One report line per tensor: its position and its name.
        rows = []
        for pos, candidate in enumerate(tensors):
            label = candidate.name if isinstance(candidate, Tensor) else "Not a Tensor"
            rows.append("      {} = {}\n".format(pos, label))
        return rows

    # Operators without an op_index can only be identified by their name.
    if op.op_index is not None:
        headline = "Invalid {} (op_index = {}) operator in the input network.".format(op.type, op.op_index)
    else:
        headline = "Invalid {} (name = {}) operator in the internal representation.".format(op.type, op.name)

    parts = [headline, " {}\n".format(msg), "   Input tensors:\n"]
    parts.extend(describe(op.inputs))
    parts.append("   Output tensors:\n")
    parts.extend(describe(op.outputs))

    # Every piece ends in a newline; drop the final one before printing.
    report = "".join(parts)[:-1]

    print("Error: {}".format(report))
    sys.exit(1)


def TensorError(tens, msg):
    """Report an invalid tensor and terminate the program.

    Prints, to standard output, an "Error: ..." report that names the tensor,
    repeats msg and lists every entry of tens.ops (driving operators) and of
    tens.consumer_list (consuming operators), then calls sys.exit(1).
    Entries that are not Operation instances are listed as "Not an Operation".
    """

    assert isinstance(tens, Tensor)

    def describe(operators):
        # One report line per operator: its position, type and op_index.
        rows = []
        for pos, candidate in enumerate(operators):
            if isinstance(candidate, Operation):
                kind, ident = candidate.type, candidate.op_index
            else:
                kind, ident = "Not an Operation", ""
            rows.append("      {} = {} ({})\n".format(pos, kind, ident))
        return rows

    parts = ["Invalid {} tensor. {}\n".format(tens.name, msg), "   Driving operators:\n"]
    parts.extend(describe(tens.ops))
    parts.append("   Consuming operators:\n")
    parts.extend(describe(tens.consumer_list))

    # Every piece ends in a newline; drop the final one before printing.
    report = "".join(parts)[:-1]

    print("Error: {}".format(report))
    sys.exit(1)