diff options
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index 495d71a6..5661f36e 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -191,6 +191,7 @@ class TFLiteSemantic: # ArgMax specific checks: self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_input_8bit) + self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_argmax_output) def is_operator_semantic_valid(self, op): ext_type = optype_to_builtintype(op.type) @@ -634,6 +635,13 @@ class TFLiteSemantic: return valid, f"Op has ifm_dtype={ifm_dtype}" @staticmethod + def constraint_argmax_output(op): + "OFM must be int32 or int64" + ofm_dtype = op.ofm.dtype + valid = ofm_dtype in (DataType.int32, DataType.int64) + return valid, f"Op has ofm_dtype={ofm_dtype}" + + @staticmethod def constraint_matching_either_shapes(op): "At least one Input's shape must match the OFM's shape" ifm_shape = op.ifm.shape |