aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r--ethosu/vela/tflite_model_semantic.py23
1 files changed, 22 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index e0541df5..ee66d4cc 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -186,7 +186,15 @@ class TFLiteSemantic:
if op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const):
return True
- for constraint in self.generic_constraints + self.specific_constraints[op.type]:
+ # Generic constraints list filtered out to exclude certain constraints depending on op.type
+ filtered_generic_constraints = []
+
+ for constraint in self.generic_constraints:
+ # Check constraint not in dictionary otherwise return empty array
+ if constraint not in self.get_generic_constraint_exclude_list().get(op.type, []):
+ filtered_generic_constraints.append(constraint)
+
+ for constraint in filtered_generic_constraints + self.specific_constraints[op.type]:
valid, extra = constraint(op)
if not valid:
print(
@@ -200,6 +208,19 @@ class TFLiteSemantic:
return True
@staticmethod
+ def get_generic_constraint_exclude_list():
+
+ # Not all generic constraints can be applied to each operator
+ generic_constraints_exclude_list = {
+ Op.Shape: [
+ TFLiteSemantic.constraint_tens_quant_none_check,
+ TFLiteSemantic.constraint_tens_quant_scale,
+ TFLiteSemantic.constraint_quant_scale_inf,
+ ]
+ }
+ return generic_constraints_exclude_list
+
+ @staticmethod
def constraint_none_const_tensors(op):
"Constant tensors should not have NoneType-values"
valid = True