aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py33
1 files changed, 30 insertions, 3 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index c1443b3b..69618d2c 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -31,10 +31,11 @@ from uuid import UUID
import numpy as np
-from . import errors # Import this way due to cyclic imports
from . import numeric_util
from .data_type import BaseType
from .data_type import DataType
+from .errors import UnsupportedFeatureError
+from .errors import VelaError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .operation import Op
from .operation import Operation
@@ -630,7 +631,7 @@ class Tensor:
if end_coord[2] > crossing_x:
addresses[1] = self.address_for_coordinate([start_coord[0], start_coord[1], crossing_x, start_coord[3]])
- raise errors.UnsupportedFeatureError("Striping in vertical direction is not supported")
+ raise UnsupportedFeatureError("Striping in vertical direction is not supported")
if end_coord[1] > crossing_y:
addresses[2] = self.address_for_coordinate([start_coord[0], crossing_y, start_coord[2], start_coord[3]])
if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
@@ -737,7 +738,7 @@ class Tensor:
if index < len(self.weight_compressed_offsets) - 1:
# There are no half-way points in the weights
if (depth % brick_depth) != 0:
- raise errors.UnsupportedFeatureError("Offset into weights must be aligned to a brick")
+ raise UnsupportedFeatureError("Offset into weights must be aligned to a brick")
return index
@@ -851,6 +852,32 @@ class Tensor:
__repr__ = __str__
+ def error(self, msg):
+ """
+ Raises a VelaError exception for errors encountered when parsing a Tensor
+
+ :param self: Tensor object that resulted in the error
+ :param msg: str object that contains a description of the specific error encountered
+ """
+
+ def _print_operators(ops):
+ lines = []
+ for idx, op in enumerate(ops):
+ op_type = getattr(op, "type", "Not an Operation")
+ op_id = getattr(op, "op_index", "-")
+ lines.append(f" {idx} = {op_type} ({op_id})")
+ return lines
+
+ lines = [f"Invalid {self.name} tensor. {msg}"]
+
+ lines += [" Driving operators:"]
+ lines += _print_operators(self.ops)
+
+ lines += [" Consuming operators:"]
+ lines += _print_operators(self.consumer_list)
+
+ raise VelaError("\n".join(lines))
+
def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
# checks that the scaling of two quantized tensors are equal