diff options
Diffstat (limited to 'ethosu/vela')
-rw-r--r-- | ethosu/vela/errors.py | 18 | ||||
-rw-r--r-- | ethosu/vela/weight_compressor.py | 17 |
2 files changed, 14 insertions, 21 deletions
diff --git a/ethosu/vela/errors.py b/ethosu/vela/errors.py index 59740aac..2c93fbc6 100644 --- a/ethosu/vela/errors.py +++ b/ethosu/vela/errors.py @@ -15,7 +15,6 @@ # limitations under the License. # Description: # Defines custom exceptions. -import inspect import sys from .operation import Operation @@ -122,20 +121,3 @@ def TensorError(tens, msg): print("Error: {}".format(data)) sys.exit(1) - - -def typecheck(func): - def wrapper(*args, **kwargs): - fsig = inspect.signature(func) - args_zipped = zip(kwargs.values(), fsig.parameters.keys()) - for actual, expected in args_zipped: - expected_type = fsig.parameters[expected].annotation - actual_type = type(actual) - if expected_type is inspect.Parameter.empty: - raise TypeError("Please provide type info for {}, hint = {}".format(expected, actual_type)) - if expected_type is not actual_type: - raise TypeError("expected : {}, but got {}".format(expected_type, actual_type)) - # Actual execution - return func(*args, **kwargs) - - return wrapper diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py index 7e03a94d..45427a1a 100644 --- a/ethosu/vela/weight_compressor.py +++ b/ethosu/vela/weight_compressor.py @@ -23,7 +23,6 @@ import numpy as np from .architecture_features import Accelerator from .architecture_features import ArchitectureFeatures from .data_type import DataType -from .errors import typecheck from .errors import UnsupportedFeatureError from .nn_graph import SchedulingStrategy from .numeric_util import round_up @@ -45,7 +44,6 @@ WeightCompressionConfig = namedtuple( ) -@typecheck def encode_weights( accelerator: Accelerator, weights_volume: np.ndarray, @@ -68,6 +66,15 @@ def encode_weights( :return: a bytearray of compressed weights """ + # Check arg types + assert isinstance(accelerator, Accelerator) + assert isinstance(weights_volume, np.ndarray) + assert isinstance(dilation_xy, tuple) + assert isinstance(ifm_bitdepth, int) + assert isinstance(ofm_block_depth, int) + assert isinstance(is_depthwise, bool) + assert isinstance(is_partkernel, bool) + # Checks for weight layout assert len(weights_volume.shape) == 4, "weights ndarray should have a shape of 4" @@ -94,7 +101,6 @@ def encode_weights( return encoded_stream -@typecheck def encode_bias(bias: np.int64, scale: int, shift: int): """ Public facing API to pack bias and scale values as required by the hardware @@ -103,6 +109,11 @@ def encode_bias(bias: np.int64, scale: int, shift: int): :param shift: 6bit shift value :return: packed 80bit [0(2-bits),shift(6-bits),scale(32-bits),bias(40-bits)] """ + # Check arg types + assert isinstance(bias, np.int64) + assert isinstance(scale, int) + assert isinstance(shift, int) + assert -(1 << (40 - 1)) <= bias < (1 << (40 - 1)) # signed 40-bit range assert 0 <= scale < (1 << 32) # unsigned 32-bit range assert 0 <= shift < (1 << 6) # unsigned 6-bit range |