diff options
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- | ethosu/vela/supported_operators.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index f4dd5796..dfb7bc7d 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -24,6 +24,8 @@ from .data_type import DataType from .numeric_util import is_integer from .operation import get_slice_offsets from .operation import Op +from .tensor import check_quantized_tens_scaling_equal +from .tensor import check_tens_quantized # Custom decorator function to allow formatting docstrings containing "{}" @@ -730,17 +732,22 @@ class SupportedOperators: @classmethod def check_quantization_restrictions_binary_elem_wise(cls, op): - # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops + # checks that IFM1, IFM2 and OFM quantization are equal for binary ops + assert len(op.inputs) >= 2 and len(op.outputs) == 1 if ( - op.inputs[0].quantization is None - or not op.inputs[0].is_scaling_equal(op.inputs[1]) - or not op.inputs[0].is_scaling_equal(op.outputs[0]) + not check_tens_quantized(op.inputs[0]) + or not check_tens_quantized(op.inputs[1]) + or not check_tens_quantized(op.outputs[0]) ): - print( - "Warning: Input/output tensors with different quantization is unsupported for the", op.type, "operator" - ) + warn_cpu(op, "has non-quantised input and/or output tensors") + return False + + if not check_quantized_tens_scaling_equal(op.inputs[0], op.inputs[1]) or not check_quantized_tens_scaling_equal( + op.inputs[0], op.outputs[0] + ): + warn_cpu(op, "has input/output tensors with different quantisation which is illegal") return False return True |