aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-05-25 15:05:26 +0200
committerTim Hall <tim.hall@arm.com>2020-06-18 17:53:52 +0100
commit7db78969dc8ead72f3ded81b6d2a6a7ed798ea62 (patch)
tree011bcf579cc8e0f007f9564a98cc5c05df34322b
parent78792223369fa34dacd0e69e189af035283da2ae (diff)
downloadethos-u-vela-7db78969dc8ead72f3ded81b6d2a6a7ed798ea62.tar.gz
MLBEDSW-2067: added custom exceptions
Added custom exceptions to handle different types of input errors. Also performed minor formatting changes using flake8/black. Change-Id: Ie5b05361507d5e569aff045757aec0a4a755ae98 Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
-rw-r--r--ethosu/vela/architecture_features.py5
-rw-r--r--ethosu/vela/errors.py48
-rw-r--r--ethosu/vela/graph_optimiser.py19
-rw-r--r--ethosu/vela/model_reader.py27
-rw-r--r--ethosu/vela/shared_buffer_allocation.py8
-rw-r--r--ethosu/vela/test/test_model_reader.py40
-rw-r--r--ethosu/vela/tflite_reader.py14
-rw-r--r--ethosu/vela/vela.py4
-rw-r--r--ethosu/vela/weight_compressor.py9
9 files changed, 139 insertions, 35 deletions
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index c712588f..1bf9d950 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -21,6 +21,7 @@ from configparser import ConfigParser
import numpy as np
+from .errors import OptionError
from .numeric_util import round_up
from .numeric_util import round_up_divide
from .operation import NpuBlockType
@@ -158,7 +159,7 @@ Note the difference between ArchitectureFeatures and CompilerOptions
self.vela_config = vela_config
self.accelerator_config = accelerator_config
if self.accelerator_config not in ArchitectureFeatures.accelerator_configs:
- raise Exception("Unknown accelerator configuration " + self.accelerator_config)
+ raise OptionError("--accelerator-config", self.accelerator_config, "Unknown accelerator configuration")
accel_config = ArchitectureFeatures.accelerator_configs[self.accelerator_config]
self.config = accel_config
@@ -564,7 +565,7 @@ Note the difference between ArchitectureFeatures and CompilerOptions
else:
section_key = "SysConfig." + self.system_config
if section_key not in self.vela_config:
- raise Exception("Unknown system configuration " + self.system_config)
+ raise OptionError("--system-config", self.system_config, "Unknown system configuration")
try:
self.npu_clock = float(self.__sys_config("npu_freq", "500e6"))
diff --git a/ethosu/vela/errors.py b/ethosu/vela/errors.py
new file mode 100644
index 00000000..efe64d5c
--- /dev/null
+++ b/ethosu/vela/errors.py
@@ -0,0 +1,48 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Defines custom exceptions.
+
+
+class VelaError(Exception):
+ """Base class for vela exceptions"""
+
+ def __init__(self, data):
+ self.data = data
+
+ def __str__(self):
+ return repr(self.data)
+
+
+class InputFileError(VelaError):
+ """Raised when reading the input file results in errors"""
+
+ def __init__(self, file_name, msg):
+ self.data = "Error reading {}: {}".format(file_name, msg)
+
+
+class UnsupportedFeatureError(VelaError):
+ """Raised when the input file uses non-supported features that cannot be handled"""
+
+ def __init__(self, data):
+ self.data = "The input file uses a feature that is currently not supported: {}".format(data)
+
+
+class OptionError(VelaError):
+ """Raised when an incorrect command line option is used"""
+
+ def __init__(self, option, option_value, msg):
+ self.data = "Incorrect argument: {} {}: {}".format(option, option_value, msg)
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 351716e0..72bb486c 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -22,6 +22,7 @@ import numpy as np
from . import rewrite_graph
from .data_type import DataType
+from .errors import UnsupportedFeatureError
from .operation import NpuBlockType
from .operation import Operation
from .tensor import Tensor
@@ -124,7 +125,7 @@ def calc_padding_and_skirt(padding_type, kernel_size, stride, input_dims):
top_pad = 0
bottom_pad = 0
else:
- assert 0, "Unknown padding"
+ raise UnsupportedFeatureError("Unknown padding {}".format(str(padding_type)))
padding = (top_pad, left_pad, bottom_pad, right_pad)
skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
return padding, skirt
@@ -214,7 +215,7 @@ def fixup_unpack_output(tens, arch):
if op.type == "StridedSlice":
new_axis_mask = op.attrs["new_axis_mask"]
shrink_axis_mask = op.attrs["shrink_axis_mask"]
- ellipsis_mask = op.attrs["ellipsis_mask"]
+ ellipsis_mask = op.attrs["ellipsis_mask"]
if (new_axis_mask != 0 and shrink_axis_mask != 0) or ellipsis_mask != 0:
# Not supported, will be put on CPU
@@ -243,7 +244,7 @@ def fixup_unpack_output(tens, arch):
n += 1
new_axis_mask &= new_axis_mask - 1
axis = int(math.log2(prev_mask - new_axis_mask))
- reshape_input_shape = reshape_input_shape[:axis] + reshape_input_shape[(axis + 1):]
+ reshape_input_shape = reshape_input_shape[:axis] + reshape_input_shape[(axis + 1) :]
new_axis_mask >>= 1
assert len(tens.shape) == (len(op.inputs[0].shape) + n)
@@ -288,7 +289,7 @@ def add_padding_fields(op, arch):
kernel_size = op.attrs["ksizes"][1:3]
input_shape = op.inputs[0].shape
else:
- assert 0, "Unknown operation that uses padding"
+ raise UnsupportedFeatureError("Unknown operation that uses padding: {}".format(op.type))
padding, skirt = calc_padding_and_skirt(op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape)
op.attrs["explicit_padding"] = padding
@@ -312,7 +313,9 @@ fc_op = set(
)
)
depthwise_op = set(("DepthwiseConv2dNative", "DepthwiseConv2dBiasAct",))
-pool_op = set(("AvgPool", "MaxPool", "QuantizedAvgPool", "QuantizedMaxPool", "AvgPoolAct", "MaxPoolAct", "ResizeBilinear",))
+pool_op = set(
+ ("AvgPool", "MaxPool", "QuantizedAvgPool", "QuantizedMaxPool", "AvgPoolAct", "MaxPoolAct", "ResizeBilinear",)
+)
elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum", "LeakyRelu", "Abs"))
binary_elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum"))
activation_ops = set(("Relu", "Relu6", "ReluN1To1", "Sigmoid", "Tanh"))
@@ -373,13 +376,11 @@ def convert_depthwise_to_conv(op, arch):
weight_tensor.quant_values.shape
)
else:
- print(
- "Error: Unsupported DepthwiseConv2d with depth_multiplier = {0}, "
- "ifm channels = {1}, ofm channels = {2}".format(
+ raise UnsupportedFeatureError(
+ "Unsupported DepthwiseConv2d with depth_multiplier = {}, ifm channels = {}, ofm channels = {}".format(
op.attrs["depth_multiplier"], ifm_tensor.shape[3], ofm_tensor.shape[3]
)
)
- assert False
return op
diff --git a/ethosu/vela/model_reader.py b/ethosu/vela/model_reader.py
index d1cdc9bd..6deb2538 100644
--- a/ethosu/vela/model_reader.py
+++ b/ethosu/vela/model_reader.py
@@ -15,6 +15,9 @@
# limitations under the License.
# Description:
# Dispatcher for reading a neural network model.
+from . import tflite_reader
+from .errors import InputFileError
+from .errors import VelaError
class ModelReaderOptions:
@@ -29,15 +32,17 @@ class ModelReaderOptions:
def read_model(fname, options, feed_dict={}, output_node_names=[], initialisation_nodes=[]):
if fname.endswith(".tflite"):
- from . import tflite_reader
-
- nng = tflite_reader.read_tflite(
- fname,
- options.batch_size,
- feed_dict=feed_dict,
- output_node_names=output_node_names,
- initialisation_nodes=initialisation_nodes,
- )
+ try:
+ return tflite_reader.read_tflite(
+ fname,
+ options.batch_size,
+ feed_dict=feed_dict,
+ output_node_names=output_node_names,
+ initialisation_nodes=initialisation_nodes,
+ )
+ except VelaError as e:
+ raise e
+ except Exception as e:
+ raise InputFileError(fname, str(e))
else:
- assert 0, "Unknown model format"
- return nng
+ raise InputFileError(fname, "Unknown input file format. Only .tflite files are supported")
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index 335b863f..2bfe5941 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -22,6 +22,7 @@ from .architecture_features import Block
from .architecture_features import Kernel
from .architecture_features import SharedBufferArea
from .architecture_features import SHRAMElements
+from .errors import OptionError
from .operation import NpuBlockType
@@ -163,8 +164,11 @@ def find_block_configs_suitable_for_pass_and_shared_buffer(arch, ps):
if arch.override_block_config:
config = alloc.try_block(arch.override_block_config)
- assert config, "Block config override cannot be used"
- return [config]
+ raise OptionError(
+ "--force-block-config",
+ str(arch.override_block_config),
+ "This forced block config value cannot be used; it is not compatible",
+ )
# Constrain the search space if the OFM is smaller than the max block size
# - Add other block search constraints here if required
diff --git a/ethosu/vela/test/test_model_reader.py b/ethosu/vela/test/test_model_reader.py
new file mode 100644
index 00000000..ee9a51e8
--- /dev/null
+++ b/ethosu/vela/test/test_model_reader.py
@@ -0,0 +1,40 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Unit tests for model_reader.
+import pytest
+from ethosu.vela import model_reader
+from ethosu.vela.errors import InputFileError
+
+
+def test_read_model_incorrect_extension(tmpdir):
+ # Tests read_model with a file name that does not end with .tflite
+ with pytest.raises(InputFileError):
+ model_reader.read_model("no_tflite_file.txt", model_reader.ModelReaderOptions())
+
+
+def test_read_model_corrupt_contents(tmpdir):
+ # Tests read_model with a corrupt .tflite file
+ fname = tmpdir.join("corrupt.tflite")
+ fname.write("abcde1234")
+ with pytest.raises(InputFileError):
+ model_reader.read_model(fname.strpath, model_reader.ModelReaderOptions())
+
+
+def test_read_model_file_not_found(tmpdir):
+ # Tests read_model with a .tflite file that does not exist
+ with pytest.raises(InputFileError):
+ model_reader.read_model("non_existing.tflite", model_reader.ModelReaderOptions())
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 7e158aac..850690f2 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -19,6 +19,7 @@ import os.path
import numpy as np
+from .errors import UnsupportedFeatureError
from .nn_graph import Graph
from .nn_graph import Subgraph
from .operation import Operation
@@ -147,18 +148,17 @@ class TFLiteSubgraph:
if op_type.startswith("ResizeBilinear"):
upscaled_shape = [op.inputs[0].shape[1] * 2, op.inputs[0].shape[2] * 2]
out_shape = op.outputs[0].shape[1:3]
- if not op.attrs['align_corners'] and out_shape == upscaled_shape:
+ if not op.attrs["align_corners"] and out_shape == upscaled_shape:
# this means the output is supposed to be a x2 upscale,
# so we need to do SAME padding
- op.attrs.update({'padding': b'SAME'})
- elif (op.attrs['align_corners']
- and out_shape == [upscaled_shape[0] - 1, upscaled_shape[1] - 1]):
+ op.attrs.update({"padding": b"SAME"})
+ elif op.attrs["align_corners"] and out_shape == [upscaled_shape[0] - 1, upscaled_shape[1] - 1]:
# here we can just run the avg pool without padding and
# produce a (M * 2 - 1, N * 2 - 1) sized output
- op.attrs.update({'padding': b'VALID'})
+ op.attrs.update({"padding": b"VALID"})
else:
- assert False, "Only 2x upscaling is supported"
- op.attrs.update({'filter_width': 2, 'filter_height': 2, 'stride_w': 1, 'stride_h': 1,})
+ raise UnsupportedFeatureError("ResizeBilinear: Only 2x upscaling is supported")
+ op.attrs.update({"filter_width": 2, "filter_height": 2, "stride_w": 1, "stride_h": 1})
if "stride_w" in op.attrs:
op.attrs["strides"] = (1, op.attrs["stride_h"], op.attrs["stride_w"], 1)
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 49f8c26c..bd5409ce 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -31,6 +31,7 @@ from . import scheduler
from . import stats_writer
from . import tflite_writer
from ._version import __version__
+from .errors import InputFileError
from .nn_graph import PassPlacement
from .nn_graph import TensorAllocator
from .scheduler import ParetoMetric
@@ -44,8 +45,7 @@ def process(fname, arch, model_reader_options, compiler_options, scheduler_optio
nng = model_reader.read_model(fname, model_reader_options)
if not nng:
- print("reading of", fname, "failed")
- assert False
+ raise InputFileError(fname, "input file could not be read")
if compiler_options.verbose_operators:
nng.print_operators()
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 04d684e6..a81b1fb4 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -23,6 +23,7 @@ from ethosu import mlw_codec
from .architecture_features import Block
from .data_type import DataType
+from .errors import UnsupportedFeatureError
from .nn_graph import SchedulingStrategy
from .numeric_util import round_up
from .operation import NpuBlockType
@@ -292,14 +293,18 @@ def calc_scales_and_pack_biases(tens, arch, oc_quantum, rescale_for_faf=False):
for weight_scale in weight_scales
]
else:
- assert False, str(ifm_dtype) + " not implemented"
+ raise UnsupportedFeatureError(
+ "Compression of {} is not implemented; tensor: {}".format(ifm_dtype, tens.name)
+ )
else:
if ifm_dtype == DataType.uint8:
scales = [np.double(ifm_scale * weight_scale * 0x3000) for weight_scale in weight_scales]
elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
scales = [(np.double(ifm_scale * 0x3000) * np.double(weight_scale)) for weight_scale in weight_scales]
else:
- assert False, str(ifm_dtype) + " not implemented"
+ raise UnsupportedFeatureError(
+ "Compression of {} is not implemented; tensor: {}".format(ifm_dtype, tens.name)
+ )
# quantise all of the weight scales into (scale_factor, shift)
if ifm_dtype == DataType.int16: