diff options
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index e0541df5..ee66d4cc 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -186,7 +186,15 @@ class TFLiteSemantic: if op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const): return True - for constraint in self.generic_constraints + self.specific_constraints[op.type]: + # Generic constraints list filtered out to exclude certain constraints depending on op.type + filtered_generic_constraints = [] + + for constraint in self.generic_constraints: + # Check constraint not in dictionary otherwise return empty array + if constraint not in self.get_generic_constraint_exclude_list().get(op.type, []): + filtered_generic_constraints.append(constraint) + + for constraint in filtered_generic_constraints + self.specific_constraints[op.type]: valid, extra = constraint(op) if not valid: print( @@ -200,6 +208,19 @@ class TFLiteSemantic: return True @staticmethod + def get_generic_constraint_exclude_list(): + + # Not all generic constraints can be applied to each operator + generic_constraints_exclude_list = { + Op.Shape: [ + TFLiteSemantic.constraint_tens_quant_none_check, + TFLiteSemantic.constraint_tens_quant_scale, + TFLiteSemantic.constraint_quant_scale_inf, + ] + } + return generic_constraints_exclude_list + + @staticmethod def constraint_none_const_tensors(op): "Constant tensors should not have NoneType-values" valid = True |