diff options
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index c8b373a3..6e2467bb 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -26,6 +26,7 @@ from .operation import get_slice_offsets from .operation import Op from .supported_operators_util import docstring_format_args from .supported_operators_util import list_formatter +from .tensor import check_quantized_tens_scaling_equal from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN from .tflite_mapping import optype_to_builtintype @@ -53,7 +54,8 @@ class TFLiteSemantic: binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.Sub,)) binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops - shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean)) + shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims)) + reshape_ops = set((Op.Reshape, Op.QuantizedReshape, Op.Squeeze, Op.ExpandDims,)) def __init__(self): # Setup the generic constraints. Note: the order matters @@ -110,6 +112,10 @@ class TFLiteSemantic: self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_signed) self.specific_constraints[op_type].append(TFLiteSemantic.constraint_unsigned_valid) + # Ops reshaping dimensions: Reshape, Squeeze and ExpandDims + for op_type in TFLiteSemantic.reshape_ops: + self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_quant) + # Softmax specific checks: self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_shapes) self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_in_out_types) @@ -518,6 +524,13 @@ class TFLiteSemantic: valid = axis in (1, 2, [1], [2], [1, 2], [2, 1]) return valid, f"Axis is {axis}" + @staticmethod + def constraint_matching_in_out_quant(op): + "Input and output quantisation must match." + if not check_quantized_tens_scaling_equal(op.ifm, op.ofm): + return False, "IFM and OFM quantisation parameters are not equal." + return True, "IFM and OFM quantisation parameters matches." + def tflite_semantic_checker(nng): semantic_checker = TFLiteSemantic() |