diff options
author | Tim Hall <tim.hall@arm.com> | 2020-09-09 21:58:15 +0100 |
---|---|---|
committer | patrik.gustavsson <patrik.gustavsson@arm.com> | 2020-10-21 12:13:37 +0000 |
commit | 9358296a51b9186335304a53bd7ea5dfbe5322d8 (patch) | |
tree | a30b8f4e092eb78984c9f15beeaabff4d36c3002 /ethosu/vela/supported_operators.py | |
parent | e8887a3e6ed6638b06ecac9581deaaa89b8059c0 (diff) | |
download | ethos-u-vela-9358296a51b9186335304a53bd7ea5dfbe5322d8.tar.gz |
vela: Improve the scaling is equal check
- Fixed and documented both tensor and quant params scaling checks
- Added quant params validity check and tensor quantisation check
- Added valid tensor checks to some graph optimisation functions
Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I8d6e8f03a603d28886dde511672c8399c85b794c
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- | ethosu/vela/supported_operators.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index f4dd5796..dfb7bc7d 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -24,6 +24,8 @@ from .data_type import DataType from .numeric_util import is_integer from .operation import get_slice_offsets from .operation import Op +from .tensor import check_quantized_tens_scaling_equal +from .tensor import check_tens_quantized # Custom decorator function to allow formatting docstrings containing "{}" @@ -730,17 +732,22 @@ class SupportedOperators: @classmethod def check_quantization_restrictions_binary_elem_wise(cls, op): - # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops + # checks that IFM1, IFM2 and OFM quantization are equal for binary ops + assert len(op.inputs) >= 2 and len(op.outputs) == 1 if ( - op.inputs[0].quantization is None - or not op.inputs[0].is_scaling_equal(op.inputs[1]) - or not op.inputs[0].is_scaling_equal(op.outputs[0]) + not check_tens_quantized(op.inputs[0]) + or not check_tens_quantized(op.inputs[1]) + or not check_tens_quantized(op.outputs[0]) ): - print( - "Warning: Input/output tensors with different quantization is unsupported for the", op.type, "operator" - ) + warn_cpu(op, "has non-quantised input and/or output tensors") + return False + + if not check_quantized_tens_scaling_equal(op.inputs[0], op.inputs[1]) or not check_quantized_tens_scaling_equal( + op.inputs[0], op.outputs[0] + ): + warn_cpu(op, "has input/output tensors with different quantisation which is illegal") return False return True |