From 528a56df829b65f7a2c61953650b123c461095f7 Mon Sep 17 00:00:00 2001 From: Michael McGeagh Date: Wed, 16 Dec 2020 11:33:21 +0000 Subject: vela: Move special error cases Due to an issue with potential cyclical imports, especially when running individual parts of vela standalone for example with pytest, the specialised error functions are moved out of errors.py to their respective locations. The use of getattr over isinstance prevents the need to import the tensor/operator class causing the cyclical import issue. Signed-off-by: Michael McGeagh Change-Id: If8cee4b1a2562660c6a47e1c7aeb5d7fd4dd1fca --- ethosu/vela/operation.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) (limited to 'ethosu/vela/operation.py') diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index afc02d41..30c32acc 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -24,6 +24,7 @@ from typing import List from typing import Optional from typing import TYPE_CHECKING +from .errors import VelaError from .numeric_util import full_shape if TYPE_CHECKING: @@ -668,3 +669,31 @@ class Operation: if self.forced_output_quantization is not None: return self.forced_output_quantization return self.ofm.quantization + + def error(self, msg): + """ + Raises a VelaError exception for errors encountered when parsing an Operation + + :param self: Operation object that resulted in the error + :param msg: str object that contains a description of the specific error encountered + """ + + def _print_tensors(tensors): + lines = [] + for idx, tens in enumerate(tensors): + tens_name = getattr(tens, "name", "Not a Tensor") + lines.append(f" {idx} = {tens_name}") + return lines + + if self.op_index is None: + lines = [f"Invalid {self.type} (name = {self.name}) operator in the internal representation. {msg}"] + else: + lines = [f"Invalid {self.type} (op_index = {self.op_index}) operator in the input network. {msg}"] + + lines += [" Input tensors:"] + lines += _print_tensors(self.inputs) + + lines += [" Output tensors:"] + lines += _print_tensors(self.outputs) + + raise VelaError("\n".join(lines)) -- cgit v1.2.1