diff options
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 33 |
1 files changed, 30 insertions, 3 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index c1443b3b..69618d2c 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -31,10 +31,11 @@ from uuid import UUID import numpy as np -from . import errors # Import this way due to cyclic imports from . import numeric_util from .data_type import BaseType from .data_type import DataType +from .errors import UnsupportedFeatureError +from .errors import VelaError from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .operation import Op from .operation import Operation @@ -630,7 +631,7 @@ class Tensor: if end_coord[2] > crossing_x: addresses[1] = self.address_for_coordinate([start_coord[0], start_coord[1], crossing_x, start_coord[3]]) - raise errors.UnsupportedFeatureError("Striping in vertical direction is not supported") + raise UnsupportedFeatureError("Striping in vertical direction is not supported") if end_coord[1] > crossing_y: addresses[2] = self.address_for_coordinate([start_coord[0], crossing_y, start_coord[2], start_coord[3]]) if end_coord[1] > crossing_y and end_coord[2] > crossing_x: @@ -737,7 +738,7 @@ class Tensor: if index < len(self.weight_compressed_offsets) - 1: # There are no half-way points in the weights if (depth % brick_depth) != 0: - raise errors.UnsupportedFeatureError("Offset into weights must be aligned to a brick") + raise UnsupportedFeatureError("Offset into weights must be aligned to a brick") return index @@ -851,6 +852,32 @@ class Tensor: __repr__ = __str__ + def error(self, msg): + """ + Raises a VelaError exception for errors encountered when parsing a Tensor + + :param self: Tensor object that resulted in the error + :param msg: str object that contains a description of the specific error encountered + """ + + def _print_operators(ops): + lines = [] + for idx, op in enumerate(ops): + op_type = getattr(op, "type", "Not an Operation") + op_id = getattr(op, "op_index", "-") + lines.append(f" {idx} = {op_type} ({op_id})") + return lines + + lines = [f"Invalid {self.name} tensor. {msg}"] + + lines += [" Driving operators:"] + lines += _print_operators(self.ops) + + lines += [" Consuming operators:"] + lines += _print_operators(self.consumer_list) + + raise VelaError("\n".join(lines)) + def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool: # checks that the scaling of two quantized tensors are equal |