aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r--ethosu/vela/tflite_model_semantic.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 9f53a1e6..495d71a6 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -77,7 +77,9 @@ class TFLiteSemantic:
)
binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
- shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims, Op.Quantize))
+ shapeless_input_ops = binary_elem_wise_main_ops | set(
+ (Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims, Op.Quantize, Op.ArgMax)
+ )
reshape_ops = set(
(
Op.Reshape,
@@ -187,6 +189,9 @@ class TFLiteSemantic:
self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_input_dims)
self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_axis)
+ # ArgMax specific checks:
+ self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_input_8bit)
+
def is_operator_semantic_valid(self, op):
ext_type = optype_to_builtintype(op.type)
@@ -226,6 +231,9 @@ class TFLiteSemantic:
TFLiteSemantic.constraint_tens_no_dynamic,
TFLiteSemantic.constraint_tens_output_scalar,
],
+ Op.ArgMax: [
+ TFLiteSemantic.constraint_tens_quant_none_check,
+ ],
}
return generic_constraints_exclude_list