From e51a05c4ab4493a6745dd15d6d6a41d0f1663552 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Alfv=C3=A9n?= Date: Wed, 11 May 2022 13:10:50 +0200 Subject: MLBEDSW-6454: Enable ReLu with negative alpha value Removing constraint for negative alpha value in ReLu for int8 and uint8. Signed-off-by: Johan Alfven Change-Id: Id7a3a30bf5d1f0a591f990bd04cd0dbbad5819c6 --- ethosu/vela/test/test_tflite_model_semantic.py | 9 +++++++-- ethosu/vela/tflite_model_semantic.py | 7 ++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index 2d6ca15a..e290dd2c 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -413,12 +413,17 @@ def test_constraint_matching_either_shapes(): def test_constraint_alpha_valid(): - # Alpha cannot be negative - op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2]) + # Alpha can only be negative for int8 and uint8 + op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2], DataType.int16) op.attrs["alpha"] = 0 assert semantic_checker.is_operator_semantic_valid(op) op.attrs["alpha"] = -1 assert not semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2], DataType.int8) + op.attrs["alpha"] = 0 + assert semantic_checker.is_operator_semantic_valid(op) + op.attrs["alpha"] = -1 + assert semantic_checker.is_operator_semantic_valid(op) def test_constraint_hardswish_dtype(): diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index c811a0d4..e0541df5 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -532,10 +532,11 @@ class TFLiteSemantic: @staticmethod def constraint_alpha_valid(op): - "Alpha must not be negative" + "Alpha only allowed to be negative if IFM is int8 or uint8" alpha = op.attrs["alpha"] - valid = alpha >= 0 - return valid, f"Op has alpha={alpha}" + ifm_dtype = op.ifm.dtype + valid = ifm_dtype == DataType.int8 or ifm_dtype == DataType.uint8 or alpha >= 0 + return valid, f"Op has alpha={alpha} and ifm_dtype={ifm_dtype} " @staticmethod def constraint_keep_dim_ifm_ofm(op): -- cgit v1.2.1