aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r--ethosu/vela/tflite_model_semantic.py15
1 files changed, 14 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index c8b373a3..6e2467bb 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -26,6 +26,7 @@ from .operation import get_slice_offsets
from .operation import Op
from .supported_operators_util import docstring_format_args
from .supported_operators_util import list_formatter
+from .tensor import check_quantized_tens_scaling_equal
from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
from .tflite_mapping import optype_to_builtintype
@@ -53,7 +54,8 @@ class TFLiteSemantic:
binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.Sub,))
binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
- shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean))
+ shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims))
+ reshape_ops = set((Op.Reshape, Op.QuantizedReshape, Op.Squeeze, Op.ExpandDims,))
def __init__(self):
# Setup the generic constraints. Note: the order matters
@@ -110,6 +112,10 @@ class TFLiteSemantic:
self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_signed)
self.specific_constraints[op_type].append(TFLiteSemantic.constraint_unsigned_valid)
+ # Ops reshaping dimensions: Reshape, Squeeze and ExpandDims
+ for op_type in TFLiteSemantic.reshape_ops:
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_quant)
+
# Softmax specific checks:
self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_shapes)
self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_in_out_types)
@@ -518,6 +524,13 @@ class TFLiteSemantic:
valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
return valid, f"Axis is {axis}"
+ @staticmethod
+ def constraint_matching_in_out_quant(op):
+ "Input and output quantisation must match."
+ if not check_quantized_tens_scaling_equal(op.ifm, op.ofm):
+ return False, "IFM and OFM quantisation parameters are not equal."
+ return True, "IFM and OFM quantisation parameters matches."
+
def tflite_semantic_checker(nng):
semantic_checker = TFLiteSemantic()