diff options
Diffstat (limited to 'ethosu/vela/weight_compressor.py')
-rw-r--r-- | ethosu/vela/weight_compressor.py | 17 |
1 files changed, 14 insertions, 3 deletions
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py index 7e03a94d..45427a1a 100644 --- a/ethosu/vela/weight_compressor.py +++ b/ethosu/vela/weight_compressor.py @@ -23,7 +23,6 @@ import numpy as np from .architecture_features import Accelerator from .architecture_features import ArchitectureFeatures from .data_type import DataType -from .errors import typecheck from .errors import UnsupportedFeatureError from .nn_graph import SchedulingStrategy from .numeric_util import round_up @@ -45,7 +44,6 @@ WeightCompressionConfig = namedtuple( ) -@typecheck def encode_weights( accelerator: Accelerator, weights_volume: np.ndarray, @@ -68,6 +66,15 @@ def encode_weights( :return: a bytearray of compressed weights """ + # Check arg types + assert isinstance(accelerator, Accelerator) + assert isinstance(weights_volume, np.ndarray) + assert isinstance(dilation_xy, tuple) + assert isinstance(ifm_bitdepth, int) + assert isinstance(ofm_block_depth, int) + assert isinstance(is_depthwise, bool) + assert isinstance(is_partkernel, bool) + # Checks for weight layout assert len(weights_volume.shape) == 4, "weights ndarray should have a shape of 4" @@ -94,7 +101,6 @@ def encode_weights( return encoded_stream -@typecheck def encode_bias(bias: np.int64, scale: int, shift: int): """ Public facing API to pack bias and scale values as required by the hardware @@ -103,6 +109,11 @@ def encode_bias(bias: np.int64, scale: int, shift: int): :param shift: 6bit shift value :return: packed 80bit [0(2-bits),shift(6-bits),scale(32-bits),bias(40-bits)] """ + # Check arg types + assert isinstance(bias, np.int64) + assert isinstance(scale, int) + assert isinstance(shift, int) + assert -(1 << (40 - 1)) <= bias < (1 << (40 - 1)) # signed 40-bit range assert 0 <= scale < (1 << 32) # unsigned 32-bit range assert 0 <= shift < (1 << 6) # unsigned 6-bit range |