aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/errors.py18
-rw-r--r--ethosu/vela/weight_compressor.py17
2 files changed, 14 insertions, 21 deletions
diff --git a/ethosu/vela/errors.py b/ethosu/vela/errors.py
index 59740aac..2c93fbc6 100644
--- a/ethosu/vela/errors.py
+++ b/ethosu/vela/errors.py
@@ -15,7 +15,6 @@
# limitations under the License.
# Description:
# Defines custom exceptions.
-import inspect
import sys
from .operation import Operation
@@ -122,20 +121,3 @@ def TensorError(tens, msg):
print("Error: {}".format(data))
sys.exit(1)
-
-
-def typecheck(func):
- def wrapper(*args, **kwargs):
- fsig = inspect.signature(func)
- args_zipped = zip(kwargs.values(), fsig.parameters.keys())
- for actual, expected in args_zipped:
- expected_type = fsig.parameters[expected].annotation
- actual_type = type(actual)
- if expected_type is inspect.Parameter.empty:
- raise TypeError("Please provide type info for {}, hint = {}".format(expected, actual_type))
- if expected_type is not actual_type:
- raise TypeError("expected : {}, but got {}".format(expected_type, actual_type))
- # Actual execution
- return func(*args, **kwargs)
-
- return wrapper
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 7e03a94d..45427a1a 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -23,7 +23,6 @@ import numpy as np
from .architecture_features import Accelerator
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
-from .errors import typecheck
from .errors import UnsupportedFeatureError
from .nn_graph import SchedulingStrategy
from .numeric_util import round_up
@@ -45,7 +44,6 @@ WeightCompressionConfig = namedtuple(
)
-@typecheck
def encode_weights(
accelerator: Accelerator,
weights_volume: np.ndarray,
@@ -68,6 +66,15 @@ def encode_weights(
:return: a bytearray of compressed weights
"""
+ # Check arg types
+ assert isinstance(accelerator, Accelerator)
+ assert isinstance(weights_volume, np.ndarray)
+ assert isinstance(dilation_xy, tuple)
+ assert isinstance(ifm_bitdepth, int)
+ assert isinstance(ofm_block_depth, int)
+ assert isinstance(is_depthwise, bool)
+ assert isinstance(is_partkernel, bool)
+
# Checks for weight layout
assert len(weights_volume.shape) == 4, "weights ndarray should have a shape of 4"
@@ -94,7 +101,6 @@ def encode_weights(
return encoded_stream
-@typecheck
def encode_bias(bias: np.int64, scale: int, shift: int):
"""
Public facing API to pack bias and scale values as required by the hardware
@@ -103,6 +109,11 @@ def encode_bias(bias: np.int64, scale: int, shift: int):
:param shift: 6bit shift value
:return: packed 80bit [0(2-bits),shift(6-bits),scale(32-bits),bias(40-bits)]
"""
+ # Check arg types
+ assert isinstance(bias, np.int64)
+ assert isinstance(scale, int)
+ assert isinstance(shift, int)
+
assert -(1 << (40 - 1)) <= bias < (1 << (40 - 1)) # signed 40-bit range
assert 0 <= scale < (1 << 32) # unsigned 32-bit range
assert 0 <= shift < (1 << 6) # unsigned 6-bit range