aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
authorTim Hall <tim.hall@arm.com>2020-09-09 21:58:15 +0100
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-10-21 12:13:37 +0000
commit9358296a51b9186335304a53bd7ea5dfbe5322d8 (patch)
treea30b8f4e092eb78984c9f15beeaabff4d36c3002 /ethosu/vela/supported_operators.py
parente8887a3e6ed6638b06ecac9581deaaa89b8059c0 (diff)
downloadethos-u-vela-9358296a51b9186335304a53bd7ea5dfbe5322d8.tar.gz
vela: Improve the scaling is equal check
- Fixed and documented both tensor and quant params scaling checks - Added quant params validity check and tensor quantisation check - Added valid tensor checks to some graph optimisation functions Signed-off-by: Tim Hall <tim.hall@arm.com> Change-Id: I8d6e8f03a603d28886dde511672c8399c85b794c
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py21
1 files changed, 14 insertions, 7 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index f4dd5796..dfb7bc7d 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -24,6 +24,8 @@ from .data_type import DataType
from .numeric_util import is_integer
from .operation import get_slice_offsets
from .operation import Op
+from .tensor import check_quantized_tens_scaling_equal
+from .tensor import check_tens_quantized
# Custom decorator function to allow formatting docstrings containing "{}"
@@ -730,17 +732,22 @@ class SupportedOperators:
@classmethod
def check_quantization_restrictions_binary_elem_wise(cls, op):
- # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops
+ # checks that IFM1, IFM2 and OFM quantization are equal for binary ops
+
assert len(op.inputs) >= 2 and len(op.outputs) == 1
if (
- op.inputs[0].quantization is None
- or not op.inputs[0].is_scaling_equal(op.inputs[1])
- or not op.inputs[0].is_scaling_equal(op.outputs[0])
+ not check_tens_quantized(op.inputs[0])
+ or not check_tens_quantized(op.inputs[1])
+ or not check_tens_quantized(op.outputs[0])
):
- print(
- "Warning: Input/output tensors with different quantization is unsupported for the", op.type, "operator"
- )
+ warn_cpu(op, "has non-quantised input and/or output tensors")
+ return False
+
+ if not check_quantized_tens_scaling_equal(op.inputs[0], op.inputs[1]) or not check_quantized_tens_scaling_equal(
+ op.inputs[0], op.outputs[0]
+ ):
+ warn_cpu(op, "has input/output tensors with different quantisation which is illegal")
return False
return True