aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r--ethosu/vela/tflite_model_semantic.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 495d71a6..5661f36e 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -191,6 +191,7 @@ class TFLiteSemantic:
# ArgMax specific checks:
self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_input_8bit)
+ self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_argmax_output)
def is_operator_semantic_valid(self, op):
ext_type = optype_to_builtintype(op.type)
@@ -634,6 +635,13 @@ class TFLiteSemantic:
return valid, f"Op has ifm_dtype={ifm_dtype}"
@staticmethod
+ def constraint_argmax_output(op):
+ "OFM must be int32 or int64"
+ ofm_dtype = op.ofm.dtype
+ valid = ofm_dtype in (DataType.int32, DataType.int64)
+ return valid, f"Op has ofm_dtype={ofm_dtype}"
+
+ @staticmethod
def constraint_matching_either_shapes(op):
"At least one Input's shape must match the OFM's shape"
ifm_shape = op.ifm.shape