diff options
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/errors.py | 69 | ||||
-rw-r--r-- | ethosu/vela/mark_tensors.py | 3 | ||||
-rw-r--r-- | ethosu/vela/operation.py | 29 | ||||
-rw-r--r-- | ethosu/vela/tensor.py | 33 | ||||
-rw-r--r-- | ethosu/vela/tflite_reader.py | 3 |
5 files changed, 61 insertions, 76 deletions
diff --git a/ethosu/vela/errors.py b/ethosu/vela/errors.py index b241db8e..04468c90 100644 --- a/ethosu/vela/errors.py +++ b/ethosu/vela/errors.py @@ -15,8 +15,6 @@ # limitations under the License. # Description: # Defines custom exceptions. -from .operation import Operation -from .tensor import Tensor class VelaError(Exception): @@ -75,70 +73,3 @@ class AllocationError(VelaError): def __init__(self, msg): super().__init__(f"Allocation failed: {msg}") - - -def OperatorError(op, msg): - """ - Raises a VelaError exception for errors encountered when parsing an Operation - - :param op: Operation object that resulted in the error - :param msg: str object that contains a description of the specific error encountered - """ - - def _print_tensors(tensors): - lines = [] - for idx, tens in enumerate(tensors): - if isinstance(tens, Tensor): - tens_name = tens.name - else: - tens_name = "Not a Tensor" - lines.append(f" {idx} = {tens_name}") - return lines - - assert isinstance(op, Operation) - - if op.op_index is None: - lines = [f"Invalid {op.type} (name = {op.name}) operator in the internal representation. {msg}"] - else: - lines = [f"Invalid {op.type} (op_index = {op.op_index}) operator in the input network. {msg}"] - - lines += [" Input tensors:"] - lines += _print_tensors(op.inputs) - - lines += [" Output tensors:"] - lines += _print_tensors(op.outputs) - - raise VelaError("\n".join(lines)) - - -def TensorError(tens, msg): - """ - Raises a VelaError exception for errors encountered when parsing a Tensor - - :param tens: Tensor object that resulted in the error - :param msg: str object that contains a description of the specific error encountered - """ - - def _print_operators(ops): - lines = [] - for idx, op in enumerate(ops): - if isinstance(op, Operation): - op_type = op.type - op_id = f"({op.op_index})" - else: - op_type = "Not an Operation" - op_id = "" - lines.append(f" {idx} = {op_type} {op_id}") - return lines - - assert isinstance(tens, Tensor) - - lines = [f"Invalid {tens.name} tensor. {msg}"] - - lines += [" Driving operators:"] - lines += _print_operators(tens.ops) - - lines += [" Consuming operators:"] - lines += _print_operators(tens.consumer_list) - - raise VelaError("\n".join(lines)) diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py index 723bd876..5a475841 100644 --- a/ethosu/vela/mark_tensors.py +++ b/ethosu/vela/mark_tensors.py @@ -15,7 +15,6 @@ # limitations under the License. # Description: # Mark purpose and select formats for Tensors. -from .errors import OperatorError from .operation import CustomType from .operation import Op from .rewrite_graph import visit_graph_post_order @@ -81,7 +80,7 @@ def rewrite_mark_tensor_purpose(op, arch): scratch_tensor.purpose = TensorPurpose.Scratch if scratch_tensor is None: - OperatorError(op, "Scratch tensor not found.") + op.error("Scratch tensor not found.") def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False): diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index afc02d41..30c32acc 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -24,6 +24,7 @@ from typing import List from typing import Optional from typing import TYPE_CHECKING +from .errors import VelaError from .numeric_util import full_shape if TYPE_CHECKING: @@ -668,3 +669,31 @@ class Operation: if self.forced_output_quantization is not None: return self.forced_output_quantization return self.ofm.quantization + + def error(self, msg): + """ + Raises a VelaError exception for errors encountered when parsing an Operation + + :param self: Operation object that resulted in the error + :param msg: str object that contains a description of the specific error encountered + """ + + def _print_tensors(tensors): + lines = [] + for idx, tens in enumerate(tensors): + tens_name = getattr(tens, "name", "Not a Tensor") + lines.append(f" {idx} = {tens_name}") + return lines + + if self.op_index is None: + lines = [f"Invalid {self.type} (name = {self.name}) operator in the internal representation. {msg}"] + else: + lines = [f"Invalid {self.type} (op_index = {self.op_index}) operator in the input network. {msg}"] + + lines += [" Input tensors:"] + lines += _print_tensors(self.inputs) + + lines += [" Output tensors:"] + lines += _print_tensors(self.outputs) + + raise VelaError("\n".join(lines)) diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index c1443b3b..69618d2c 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -31,10 +31,11 @@ from uuid import UUID import numpy as np -from . import errors # Import this way due to cyclic imports from . import numeric_util from .data_type import BaseType from .data_type import DataType +from .errors import UnsupportedFeatureError +from .errors import VelaError from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .operation import Op from .operation import Operation @@ -630,7 +631,7 @@ class Tensor: if end_coord[2] > crossing_x: addresses[1] = self.address_for_coordinate([start_coord[0], start_coord[1], crossing_x, start_coord[3]]) - raise errors.UnsupportedFeatureError("Striping in vertical direction is not supported") + raise UnsupportedFeatureError("Striping in vertical direction is not supported") if end_coord[1] > crossing_y: addresses[2] = self.address_for_coordinate([start_coord[0], crossing_y, start_coord[2], start_coord[3]]) if end_coord[1] > crossing_y and end_coord[2] > crossing_x: @@ -737,7 +738,7 @@ class Tensor: if index < len(self.weight_compressed_offsets) - 1: # There are no half-way points in the weights if (depth % brick_depth) != 0: - raise errors.UnsupportedFeatureError("Offset into weights must be aligned to a brick") + raise UnsupportedFeatureError("Offset into weights must be aligned to a brick") return index @@ -851,6 +852,32 @@ class Tensor: __repr__ = __str__ + def error(self, msg): + """ + Raises a VelaError exception for errors encountered when parsing a Tensor + + :param self: Tensor object that resulted in the error + :param msg: str object that contains a description of the specific error encountered + """ + + def _print_operators(ops): + lines = [] + for idx, op in enumerate(ops): + op_type = getattr(op, "type", "Not an Operation") + op_id = getattr(op, "op_index", "-") + lines.append(f" {idx} = {op_type} ({op_id})") + return lines + + lines = [f"Invalid {self.name} tensor. {msg}"] + + lines += [" Driving operators:"] + lines += _print_operators(self.ops) + + lines += [" Consuming operators:"] + lines += _print_operators(self.consumer_list) + + raise VelaError("\n".join(lines)) + def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool: # checks that the scaling of two quantized tensors are equal diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index eff702b3..21ff8873 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -20,7 +20,6 @@ import os.path import numpy as np from .errors import InputFileError -from .errors import TensorError from .nn_graph import Graph from .nn_graph import Subgraph from .operation import create_activation_function @@ -77,7 +76,7 @@ class TFLiteSubgraph: # Fix up tensors without operations. Generate either Placeholder or Constant ops for tens in self.inputs: if tens.ops != []: - TensorError(tens, "This subgraph input tensor has unexpected driving operators.") + tens.error("This subgraph input tensor has unexpected driving operators.") op = Operation(Op.Placeholder, tens.name) op.set_output_tensor(tens) |