aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/architecture_features.py17
-rw-r--r--ethosu/vela/compiler_driver.py4
-rw-r--r--ethosu/vela/errors.py99
-rw-r--r--ethosu/vela/graph_optimiser.py14
-rw-r--r--ethosu/vela/high_level_command_to_npu_op.py3
-rw-r--r--ethosu/vela/shared_buffer_allocation.py4
-rw-r--r--ethosu/vela/tensor.py5
-rw-r--r--ethosu/vela/tensor_allocation.py12
-rw-r--r--ethosu/vela/tflite_reader.py5
-rw-r--r--ethosu/vela/tflite_writer.py3
-rw-r--r--ethosu/vela/vela.py4
-rw-r--r--ethosu/vela/weight_compressor.py8
12 files changed, 76 insertions, 102 deletions
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index 354ab12c..576f793a 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -531,13 +531,11 @@ class ArchitectureFeatures:
self._set_default_sys_config()
elif vela_config_files is None:
- raise CliOptionError("--config", vela_config_files, "CLI Option not specified")
+ raise CliOptionError("--config", vela_config_files, "Vela config file not specified")
else:
raise CliOptionError(
- "--system-config",
- self.system_config,
- "Section {} not found in Vela config file".format(sys_cfg_section),
+ "--system-config", self.system_config, f"Section {sys_cfg_section} not found in Vela config file",
)
# read the memory mode
@@ -559,11 +557,11 @@ class ArchitectureFeatures:
self._set_default_mem_mode()
elif vela_config_files is None:
- raise CliOptionError("--config", vela_config_files, "CLI Option not specified")
+ raise CliOptionError("--config", vela_config_files, "Vela config file not specified")
else:
raise CliOptionError(
- "--memory-mode", self.memory_mode, "Section {} not found in Vela config file".format(mem_mode_section),
+ "--memory-mode", self.memory_mode, f"Section {mem_mode_section} not found in Vela config file",
)
# override sram to onchipflash
@@ -645,9 +643,7 @@ class ArchitectureFeatures:
particular option then the key from the parent section is used, regardless of the parsing order
"""
if not self.vela_config.has_section(section):
- raise ConfigOptionError(
- "section", "{}. The section was not found in the Vela config file(s)".format(section)
- )
+ raise ConfigOptionError("section", f"{section}. The section was not found in the Vela config file(s)")
result = str(current_value)
if self.vela_config.has_option(section, "inherit"):
@@ -655,8 +651,7 @@ class ArchitectureFeatures:
# check for recursion loop
if inheritance_section == section:
raise ConfigOptionError(
- "inherit",
- "{}. This references its own section and recursion is not allowed".format(inheritance_section),
+ "inherit", f"{inheritance_section}. This references its own section and recursion is not allowed",
)
result = self._read_config(inheritance_section, key, result)
diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py
index 6c7fdc1a..78d7f12a 100644
--- a/ethosu/vela/compiler_driver.py
+++ b/ethosu/vela/compiler_driver.py
@@ -267,9 +267,9 @@ def compiler_driver(nng, arch, options, scheduler_options):
alloc_results.append(alloc_success)
if not alloc_results[-1]:
raise VelaError(
- "Sram limit {} bytes, has been exceeded by the scratch fast tensor. "
+ f"Sram limit {arch.sram_size} bytes, has been exceeded by the scratch fast tensor. "
"Increasing the value of --weight-estimation-scaling may help to resolve the issue. "
- "See OPTIONS.md for more information.".format(arch.sram_size)
+ "See OPTIONS.md for more information"
)
else:
tensor_allocation.allocate_tensors(
diff --git a/ethosu/vela/errors.py b/ethosu/vela/errors.py
index 2a635d0e..b241db8e 100644
--- a/ethosu/vela/errors.py
+++ b/ethosu/vela/errors.py
@@ -23,7 +23,7 @@ class VelaError(Exception):
"""Base class for vela exceptions"""
def __init__(self, data):
- self.data = "Error: " + data
+ self.data = f"Error! {data}"
def __str__(self):
return repr(self.data)
@@ -33,14 +33,14 @@ class InputFileError(VelaError):
"""Raised when reading an input file results in errors"""
def __init__(self, file_name, msg):
- self.data = "Reading input file {}: {}".format(file_name, msg)
+ super().__init__(f"Reading input file '{file_name}': {msg}")
class UnsupportedFeatureError(VelaError):
"""Raised when the input network uses non-supported features that cannot be handled"""
def __init__(self, data):
- self.data = "Input network uses a feature that is currently not supported: {}".format(data)
+ super().__init__(f"Input network uses a feature that is currently not supported: {data}")
class CliOptionError(VelaError):
@@ -52,7 +52,7 @@ class CliOptionError(VelaError):
"""
def __init__(self, option, option_value, msg):
- self.data = "Incorrect argument to CLI option: {} = {}: {}".format(option, option_value, msg)
+ super().__init__(f"Incorrect argument to CLI option {option}={option_value}: {msg}")
class ConfigOptionError(VelaError):
@@ -64,18 +64,17 @@ class ConfigOptionError(VelaError):
"""
def __init__(self, option, option_value, option_valid_values=None):
- self.data = "Invalid configuration of {} = {}".format(option, option_value)
+ data = f"Invalid configuration of {option}={option_value}"
if option_valid_values is not None:
- self.data += " (must be {}).".format(option_valid_values)
- else:
- self.data += "."
+ data += f" (must be {option_valid_values})"
+ super().__init__(data)
class AllocationError(VelaError):
"""Raised when allocation fails"""
def __init__(self, msg):
- self.data = msg
+ super().__init__(f"Allocation failed: {msg}")
def OperatorError(op, msg):
@@ -86,36 +85,30 @@ def OperatorError(op, msg):
:param msg: str object that contains a description of the specific error encountered
"""
+ def _print_tensors(tensors):
+ lines = []
+ for idx, tens in enumerate(tensors):
+ if isinstance(tens, Tensor):
+ tens_name = tens.name
+ else:
+ tens_name = "Not a Tensor"
+ lines.append(f" {idx} = {tens_name}")
+ return lines
+
assert isinstance(op, Operation)
if op.op_index is None:
- data = "Invalid {} (name = {}) operator in the internal representation.".format(op.type, op.name)
+ lines = [f"Invalid {op.type} (name = {op.name}) operator in the internal representation. {msg}"]
else:
- data = "Invalid {} (op_index = {}) operator in the input network.".format(op.type, op.op_index)
-
- data += " {}\n".format(msg)
-
- data += " Input tensors:\n"
- for idx, tens in enumerate(op.inputs):
- if isinstance(tens, Tensor):
- tens_name = tens.name
- else:
- tens_name = "Not a Tensor"
-
- data += " {} = {}\n".format(idx, tens_name)
-
- data += " Output tensors:\n"
- for idx, tens in enumerate(op.outputs):
- if isinstance(tens, Tensor):
- tens_name = tens.name
- else:
- tens_name = "Not a Tensor"
+ lines = [f"Invalid {op.type} (op_index = {op.op_index}) operator in the input network. {msg}"]
- data += " {} = {}\n".format(idx, tens_name)
+ lines += [" Input tensors:"]
+ lines += _print_tensors(op.inputs)
- data = data[:-1] # remove last newline
+ lines += [" Output tensors:"]
+ lines += _print_tensors(op.outputs)
- raise VelaError(data)
+ raise VelaError("\n".join(lines))
def TensorError(tens, msg):
@@ -126,32 +119,26 @@ def TensorError(tens, msg):
:param msg: str object that contains a description of the specific error encountered
"""
- assert isinstance(tens, Tensor)
-
- data = "Invalid {} tensor. {}\n".format(tens.name, msg)
+ def _print_operators(ops):
+ lines = []
+ for idx, op in enumerate(ops):
+ if isinstance(op, Operation):
+ op_type = op.type
+ op_id = f"({op.op_index})"
+ else:
+ op_type = "Not an Operation"
+ op_id = ""
+ lines.append(f" {idx} = {op_type} {op_id}")
+ return lines
- data += " Driving operators:\n"
- for idx, op in enumerate(tens.ops):
- if isinstance(op, Operation):
- op_type = op.type
- op_id = op.op_index
- else:
- op_type = "Not an Operation"
- op_id = ""
-
- data += " {} = {} ({})\n".format(idx, op_type, op_id)
+ assert isinstance(tens, Tensor)
- data += " Consuming operators:\n"
- for idx, op in enumerate(tens.consumer_list):
- if isinstance(op, Operation):
- op_type = op.type
- op_id = op.op_index
- else:
- op_type = "Not an Operation"
- op_id = ""
+ lines = [f"Invalid {tens.name} tensor. {msg}"]
- data += " {} = {} ({})\n".format(idx, op_type, op_id)
+ lines += [" Driving operators:"]
+ lines += _print_operators(tens.ops)
- data = data[:-1] # remove last newline
+ lines += [" Consuming operators:"]
+ lines += _print_operators(tens.consumer_list)
- raise VelaError(data)
+ raise VelaError("\n".join(lines))
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 13f08f26..15d13522 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -42,6 +42,7 @@ from .tensor import create_const_tensor
from .tensor import create_reshape_tensor
from .tensor import QuantizationParameters
from .tensor import Tensor
+from .tflite_mapping import optype_to_builtintype
passthrough_nodes = (Op.Identity,)
@@ -157,7 +158,7 @@ def calc_padding_and_skirt(padding_type, kernel_size, stride, input_dims):
top_pad = 0
bottom_pad = 0
else:
- raise UnsupportedFeatureError("Unknown padding {}".format(str(padding_type)))
+ raise UnsupportedFeatureError(f"Unknown padding {padding_type.decode('utf-8')}")
padding = (top_pad, left_pad, bottom_pad, right_pad)
skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
return padding, skirt
@@ -168,19 +169,17 @@ def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_dim
if padding_type == b"SAME":
ypad = needed_total_padding(int(input_dims[1]) * upscaling_factor, int(stride[1]), int(kernel_height))
xpad = needed_total_padding(int(input_dims[2]) * upscaling_factor, int(stride[2]), int(kernel_width))
-
right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
left_pad = max(kernel_width - 1 - right_pad, 0)
top_pad = max(kernel_height - 1 - bottom_pad, 0)
-
elif padding_type == b"VALID":
right_pad = max(kernel_width - 2, 0)
bottom_pad = max(kernel_height - 2, 0)
left_pad = kernel_width - 1
top_pad = kernel_height - 1
else:
- assert 0, "Unknown padding"
+ raise UnsupportedFeatureError(f"Unknown padding {padding_type.decode('utf-8')}")
padding = (top_pad, left_pad, bottom_pad, right_pad)
skirt = padding
@@ -504,7 +503,7 @@ def add_padding_fields(op, arch, nng):
kernel_size = op.attrs["ksize"][1:3]
input_shape = op.inputs[0].shape
else:
- raise UnsupportedFeatureError("Unknown operation that uses padding: {}".format(op.type))
+ raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
if op.type == Op.Conv2DBackpropInputSwitchedBias:
upscaling_factor = op.outputs[0].shape[1] // input_shape[1]
@@ -560,9 +559,8 @@ def convert_depthwise_to_conv(op, arch, nng):
weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
else:
raise UnsupportedFeatureError(
- "Unsupported DepthwiseConv2d with depth_multiplier = {}, ifm channels = {}, ofm channels = {}".format(
- op.attrs["depth_multiplier"], ifm_tensor.shape[3], ofm_tensor.shape[3]
- )
+ f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
+ f" ifm channels = {ifm_tensor.shape[3]}, ofm channels = {ofm_tensor.shape[3]}",
)
DebugDatabase.add_optimised(op, op)
return op
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 9e0ed010..096a65cc 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -46,6 +46,7 @@ from .architecture_features import ArchitectureFeatures
from .architecture_features import Block
from .data_type import DataType
from .debug_database import DebugDatabase
+from .errors import UnsupportedFeatureError
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
@@ -307,7 +308,7 @@ def create_npu_activation(op: Operation) -> NpuActivation:
elif faf == Op.LUT:
act_op = NpuActivationOp.TABLE_LOOKUP
elif not faf.is_relu_op():
- raise Exception("Unsupported fused_activation_function = " + faf.name)
+ raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")
act = NpuActivation(act_op)
act.min = op.activation.min
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index 21b048bc..600b3170 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -26,7 +26,7 @@ from .architecture_features import ArchitectureFeatures
from .architecture_features import Block
from .architecture_features import SharedBufferArea
from .architecture_features import SHRAMElements
-from .errors import VelaError
+from .errors import AllocationError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .operation import Kernel
from .operation import NpuBlockType
@@ -259,7 +259,7 @@ def find_suitable_block_configs(arch, alloc: SharedBufferAllocation) -> List[Tup
if arch.override_block_config:
config = alloc.try_block(arch.override_block_config)
if config is None:
- raise VelaError("Block config override '{0}' cannot be allocated".format(arch.override_block_config))
+ raise AllocationError(f"Block config override '{arch.override_block_config}' cannot be allocated")
return [config]
# Constrain the search space if the OFM is smaller than the max block size
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index d75b7879..de97710a 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -29,6 +29,7 @@ from uuid import UUID
import numpy as np
+from . import errors # Import this way due to cyclic imports
from . import numeric_util
from .data_type import BaseType
from .data_type import DataType
@@ -610,7 +611,7 @@ class Tensor:
if end_coord[2] > crossing_x:
addresses[1] = self.address_for_coordinate([start_coord[0], start_coord[1], crossing_x, start_coord[3]])
- raise Exception("Striping in vertical direction is not supported")
+ raise errors.UnsupportedFeatureError("Striping in vertical direction is not supported")
if end_coord[1] > crossing_y:
addresses[2] = self.address_for_coordinate([start_coord[0], crossing_y, start_coord[2], start_coord[3]])
if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
@@ -717,7 +718,7 @@ class Tensor:
if index < len(self.weight_compressed_offsets) - 1:
# There are no half-way points in the weights
if (depth % brick_depth) != 0:
- raise Exception("Offset into weights must be aligned to a brick")
+ raise errors.UnsupportedFeatureError("Offset into weights must be aligned to a brick")
return index
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index 9736ca22..7f66579e 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -91,7 +91,7 @@ def verify_alignment(live_ranges: LiveRangeGraph, alignment: int):
if not all(op and op.run_on_npu for op in tens.ops + tens.consumer_list):
# This is a CPU tensor, verify alignment
if tens.address % alignment != 0:
- raise AllocationError("Tensor {} not aligned to {} bytes".format(tens.name, alignment))
+ raise AllocationError(f"Tensor '{tens.name}' not aligned to {alignment} bytes")
def verify_allocation(live_ranges: LiveRangeGraph, alignment: int):
@@ -104,14 +104,8 @@ def verify_allocation(live_ranges: LiveRangeGraph, alignment: int):
overlap, tens_n, tens_m = n.overlaps_address(m)
if overlap and not (tens_n.equivalent(tens_m) and tens_n.address == tens_m.address):
raise AllocationError(
- "Overlapping buffers: {}: {} -> {} and {}: {} -> {}".format(
- n.name,
- tens_n.address,
- tens_n.address + n.size,
- m.name,
- tens_m.address,
- tens_m.address + m.size,
- )
+ f"Overlapping buffers: {n.name}: {tens_n.address} -> {tens_n.address + n.size}"
+ f" and {m.name}: {tens_m.address} -> {tens_m.address + m.size}"
)
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index df52478a..eff702b3 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -264,8 +264,9 @@ class TFLiteGraph:
def parse_operator_code(self, code):
c = code.BuiltinCode()
if c not in builtin_operator_map:
- msg = "The input file contains operator code {} which is currently not supported".format(c)
- raise InputFileError(self.name, msg)
+ raise InputFileError(
+ self.name, f"The input file contains operator code '{c}' which is currently not supported"
+ )
op_type, ser = builtin_operator_map[c]
custom_code = None
if c == BuiltinOperator.CUSTOM:
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index f747d471..06026ba5 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -21,6 +21,7 @@ import numpy as np
from flatbuffers import encode
from flatbuffers.builder import UOffsetTFlags
+from .errors import VelaError
from .nn_graph import PassPlacement
from .operation import Op
from .tensor import MemType
@@ -44,7 +45,7 @@ tflite_file_identifier = "TFL" + str(tflite_version)
def FinishWithFileIdentifier(self, rootTable, fid):
if fid is None or len(fid) != 4:
- raise Exception("fid must be 4 chars")
+ raise VelaError("FileIdentifier must be 4 chars")
flags = N.Uint8Flags
prepSize = 4
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index d27eef0e..bfc76ec9 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -55,7 +55,7 @@ def process(input_name, enable_debug_db, arch, model_reader_options, compiler_op
nng = model_reader.read_model(input_name, model_reader_options)
if not nng:
- raise InputFileError(input_name, "input file could not be read")
+ raise InputFileError(input_name, "Input file could not be read")
if compiler_options.verbose_operators:
nng.print_operators()
@@ -364,7 +364,7 @@ def main(args=None):
if args.config is not None:
for filename in args.config:
if not os.access(filename, os.R_OK):
- raise InputFileError(filename, "File not found or is not readable.")
+ raise InputFileError(filename, "File not found or is not readable")
sys.setrecursionlimit(args.recursion_limit)
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index fce17d19..188b16ad 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -452,18 +452,14 @@ def calc_scales_and_pack_biases(tens, arch, ofm_depth_step, rescale_for_faf=Fals
for weight_scale in weight_scales
]
else:
- raise UnsupportedFeatureError(
- "Compression of {} is not implemented; tensor: {}".format(ifm_dtype, tens.name)
- )
+ raise UnsupportedFeatureError(f"Compression of {ifm_dtype} is not implemented; Tensor: '{tens.name}'")
else:
if ifm_dtype == DataType.uint8:
scales = [np.double(ifm_scale * weight_scale * 0x3000) for weight_scale in weight_scales]
elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
scales = [(np.double(ifm_scale * 0x3000) * np.double(weight_scale)) for weight_scale in weight_scales]
else:
- raise UnsupportedFeatureError(
- "Compression of {} is not implemented; tensor: {}".format(ifm_dtype, tens.name)
- )
+ raise UnsupportedFeatureError(f"Compression of {ifm_dtype} is not implemented; Tensor: '{tens.name}'")
# quantise all of the weight scales into (scale_factor, shift)
if ifm_dtype == DataType.int16: