aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
authorMichael McGeagh <michael.mcgeagh@arm.com>2020-12-16 11:33:21 +0000
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-12-18 15:02:32 +0000
commit528a56df829b65f7a2c61953650b123c461095f7 (patch)
treee616cdfff4b40a29d362bab51e6641ec364ae115 /ethosu/vela/tensor.py
parent1a184e4a50ad2f3cc8c5bfcd23e0f875c089314c (diff)
downloadethos-u-vela-528a56df829b65f7a2c61953650b123c461095f7.tar.gz
vela: Move special error cases
Due to an issue with potential cyclical imports, especially when running individual parts of vela standalone for example with pytest, the specialised error functions are moved out of errors.py to their respective locations. The use of getattr over isinstance prevents the need to import the tensor/operator class causing the cyclical import issue. Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com> Change-Id: If8cee4b1a2562660c6a47e1c7aeb5d7fd4dd1fca
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py33
1 files changed, 30 insertions, 3 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index c1443b3b..69618d2c 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -31,10 +31,11 @@ from uuid import UUID
import numpy as np
-from . import errors # Import this way due to cyclic imports
from . import numeric_util
from .data_type import BaseType
from .data_type import DataType
+from .errors import UnsupportedFeatureError
+from .errors import VelaError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .operation import Op
from .operation import Operation
@@ -630,7 +631,7 @@ class Tensor:
if end_coord[2] > crossing_x:
addresses[1] = self.address_for_coordinate([start_coord[0], start_coord[1], crossing_x, start_coord[3]])
- raise errors.UnsupportedFeatureError("Striping in vertical direction is not supported")
+ raise UnsupportedFeatureError("Striping in vertical direction is not supported")
if end_coord[1] > crossing_y:
addresses[2] = self.address_for_coordinate([start_coord[0], crossing_y, start_coord[2], start_coord[3]])
if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
@@ -737,7 +738,7 @@ class Tensor:
if index < len(self.weight_compressed_offsets) - 1:
# There are no half-way points in the weights
if (depth % brick_depth) != 0:
- raise errors.UnsupportedFeatureError("Offset into weights must be aligned to a brick")
+ raise UnsupportedFeatureError("Offset into weights must be aligned to a brick")
return index
@@ -851,6 +852,32 @@ class Tensor:
__repr__ = __str__
+ def error(self, msg):
+ """
+ Raises a VelaError exception for errors encountered when parsing a Tensor
+
+ :param self: Tensor object that resulted in the error
+ :param msg: str object that contains a description of the specific error encountered
+ """
+
+ def _print_operators(ops):
+ lines = []
+ for idx, op in enumerate(ops):
+ op_type = getattr(op, "type", "Not an Operation")
+ op_id = getattr(op, "op_index", "-")
+ lines.append(f" {idx} = {op_type} ({op_id})")
+ return lines
+
+ lines = [f"Invalid {self.name} tensor. {msg}"]
+
+ lines += [" Driving operators:"]
+ lines += _print_operators(self.ops)
+
+ lines += [" Consuming operators:"]
+ lines += _print_operators(self.consumer_list)
+
+ raise VelaError("\n".join(lines))
+
def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
# checks that the scaling of two quantized tensors are equal