aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela')
-rw-r--r--ethosu/vela/errors.py69
-rw-r--r--ethosu/vela/mark_tensors.py3
-rw-r--r--ethosu/vela/operation.py29
-rw-r--r--ethosu/vela/tensor.py33
-rw-r--r--ethosu/vela/tflite_reader.py3
5 files changed, 61 insertions, 76 deletions
diff --git a/ethosu/vela/errors.py b/ethosu/vela/errors.py
index b241db8e..04468c90 100644
--- a/ethosu/vela/errors.py
+++ b/ethosu/vela/errors.py
@@ -15,8 +15,6 @@
# limitations under the License.
# Description:
# Defines custom exceptions.
-from .operation import Operation
-from .tensor import Tensor
class VelaError(Exception):
@@ -75,70 +73,3 @@ class AllocationError(VelaError):
def __init__(self, msg):
super().__init__(f"Allocation failed: {msg}")
-
-
-def OperatorError(op, msg):
- """
- Raises a VelaError exception for errors encountered when parsing an Operation
-
- :param op: Operation object that resulted in the error
- :param msg: str object that contains a description of the specific error encountered
- """
-
- def _print_tensors(tensors):
- lines = []
- for idx, tens in enumerate(tensors):
- if isinstance(tens, Tensor):
- tens_name = tens.name
- else:
- tens_name = "Not a Tensor"
- lines.append(f" {idx} = {tens_name}")
- return lines
-
- assert isinstance(op, Operation)
-
- if op.op_index is None:
- lines = [f"Invalid {op.type} (name = {op.name}) operator in the internal representation. {msg}"]
- else:
- lines = [f"Invalid {op.type} (op_index = {op.op_index}) operator in the input network. {msg}"]
-
- lines += [" Input tensors:"]
- lines += _print_tensors(op.inputs)
-
- lines += [" Output tensors:"]
- lines += _print_tensors(op.outputs)
-
- raise VelaError("\n".join(lines))
-
-
-def TensorError(tens, msg):
- """
- Raises a VelaError exception for errors encountered when parsing a Tensor
-
- :param tens: Tensor object that resulted in the error
- :param msg: str object that contains a description of the specific error encountered
- """
-
- def _print_operators(ops):
- lines = []
- for idx, op in enumerate(ops):
- if isinstance(op, Operation):
- op_type = op.type
- op_id = f"({op.op_index})"
- else:
- op_type = "Not an Operation"
- op_id = ""
- lines.append(f" {idx} = {op_type} {op_id}")
- return lines
-
- assert isinstance(tens, Tensor)
-
- lines = [f"Invalid {tens.name} tensor. {msg}"]
-
- lines += [" Driving operators:"]
- lines += _print_operators(tens.ops)
-
- lines += [" Consuming operators:"]
- lines += _print_operators(tens.consumer_list)
-
- raise VelaError("\n".join(lines))
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index 723bd876..5a475841 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -15,7 +15,6 @@
# limitations under the License.
# Description:
# Mark purpose and select formats for Tensors.
-from .errors import OperatorError
from .operation import CustomType
from .operation import Op
from .rewrite_graph import visit_graph_post_order
@@ -81,7 +80,7 @@ def rewrite_mark_tensor_purpose(op, arch):
scratch_tensor.purpose = TensorPurpose.Scratch
if scratch_tensor is None:
- OperatorError(op, "Scratch tensor not found.")
+ op.error("Scratch tensor not found.")
def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False):
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index afc02d41..30c32acc 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -24,6 +24,7 @@ from typing import List
from typing import Optional
from typing import TYPE_CHECKING
+from .errors import VelaError
from .numeric_util import full_shape
if TYPE_CHECKING:
@@ -668,3 +669,31 @@ class Operation:
if self.forced_output_quantization is not None:
return self.forced_output_quantization
return self.ofm.quantization
+
+ def error(self, msg):
+ """
+ Raises a VelaError exception for errors encountered when parsing an Operation
+
+ :param self: Operation object that resulted in the error
+ :param msg: str object that contains a description of the specific error encountered
+ """
+
+ def _print_tensors(tensors):
+ lines = []
+ for idx, tens in enumerate(tensors):
+ tens_name = getattr(tens, "name", "Not a Tensor")
+ lines.append(f" {idx} = {tens_name}")
+ return lines
+
+ if self.op_index is None:
+ lines = [f"Invalid {self.type} (name = {self.name}) operator in the internal representation. {msg}"]
+ else:
+ lines = [f"Invalid {self.type} (op_index = {self.op_index}) operator in the input network. {msg}"]
+
+ lines += [" Input tensors:"]
+ lines += _print_tensors(self.inputs)
+
+ lines += [" Output tensors:"]
+ lines += _print_tensors(self.outputs)
+
+ raise VelaError("\n".join(lines))
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index c1443b3b..69618d2c 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -31,10 +31,11 @@ from uuid import UUID
import numpy as np
-from . import errors # Import this way due to cyclic imports
from . import numeric_util
from .data_type import BaseType
from .data_type import DataType
+from .errors import UnsupportedFeatureError
+from .errors import VelaError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .operation import Op
from .operation import Operation
@@ -630,7 +631,7 @@ class Tensor:
if end_coord[2] > crossing_x:
addresses[1] = self.address_for_coordinate([start_coord[0], start_coord[1], crossing_x, start_coord[3]])
- raise errors.UnsupportedFeatureError("Striping in vertical direction is not supported")
+ raise UnsupportedFeatureError("Striping in vertical direction is not supported")
if end_coord[1] > crossing_y:
addresses[2] = self.address_for_coordinate([start_coord[0], crossing_y, start_coord[2], start_coord[3]])
if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
@@ -737,7 +738,7 @@ class Tensor:
if index < len(self.weight_compressed_offsets) - 1:
# There are no half-way points in the weights
if (depth % brick_depth) != 0:
- raise errors.UnsupportedFeatureError("Offset into weights must be aligned to a brick")
+ raise UnsupportedFeatureError("Offset into weights must be aligned to a brick")
return index
@@ -851,6 +852,32 @@ class Tensor:
__repr__ = __str__
+ def error(self, msg):
+ """
+ Raises a VelaError exception for errors encountered when parsing a Tensor
+
+ :param self: Tensor object that resulted in the error
+ :param msg: str object that contains a description of the specific error encountered
+ """
+
+ def _print_operators(ops):
+ lines = []
+ for idx, op in enumerate(ops):
+ op_type = getattr(op, "type", "Not an Operation")
+ op_id = getattr(op, "op_index", "-")
+ lines.append(f" {idx} = {op_type} ({op_id})")
+ return lines
+
+ lines = [f"Invalid {self.name} tensor. {msg}"]
+
+ lines += [" Driving operators:"]
+ lines += _print_operators(self.ops)
+
+ lines += [" Consuming operators:"]
+ lines += _print_operators(self.consumer_list)
+
+ raise VelaError("\n".join(lines))
+
def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
# checks that the scaling of two quantized tensors are equal
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index eff702b3..21ff8873 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -20,7 +20,6 @@ import os.path
import numpy as np
from .errors import InputFileError
-from .errors import TensorError
from .nn_graph import Graph
from .nn_graph import Subgraph
from .operation import create_activation_function
@@ -77,7 +76,7 @@ class TFLiteSubgraph:
# Fix up tensors without operations. Generate either Placeholder or Constant ops
for tens in self.inputs:
if tens.ops != []:
- TensorError(tens, "This subgraph input tensor has unexpected driving operators.")
+ tens.error("This subgraph input tensor has unexpected driving operators.")
op = Operation(Op.Placeholder, tens.name)
op.set_output_tensor(tens)