From 9fb745d92173abfa270e99bd5c9bd7cf85bfeb31 Mon Sep 17 00:00:00 2001 From: Thibaut Goetghebuer-Planchon Date: Wed, 23 Nov 2022 11:42:43 +0000 Subject: Models with tfl.custom ops are considered to be non TOSA compliant Change-Id: Ib31b44b3819bbdf517b96d834879155eebb4f09c --- tests/test_tosa_checker.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) (limited to 'tests') diff --git a/tests/test_tosa_checker.py b/tests/test_tosa_checker.py index eb49e65..b24073a 100644 --- a/tests/test_tosa_checker.py +++ b/tests/test_tosa_checker.py @@ -39,6 +39,29 @@ def build_tosa_non_compat_model(): return model +@pytest.fixture(scope="module") +def build_tosa_non_compat_model_custom_op(): + @tf.function( + experimental_implements='name: "exp_log" \ + attr { \ + key: "tfl_fusable_op" \ + value { b: true } \ + }' + ) + def exp_log(x): + x = tf.math.exp(x) + x = tf.math.log(x) + + return x + + input = tf.keras.layers.Input(shape=(16,)) + x = tf.keras.layers.Lambda(exp_log)(input) + x = tf.keras.layers.Dense(8, activation="relu")(x) + model = tf.keras.models.Model(inputs=[input], outputs=x) + + return model + + @pytest.fixture(scope="module") def build_tosa_compat_model(): input = tf.keras.layers.Input(shape=(16,)) @@ -62,6 +85,15 @@ def non_compat_file(build_tosa_non_compat_model): yield file +@pytest.fixture(scope="module") +def non_compat_file_custom_op(build_tosa_non_compat_model_custom_op): + tflite_model = create_tflite(build_tosa_non_compat_model_custom_op) + with tempfile.TemporaryDirectory() as tmp_dir: + file = os.path.join(tmp_dir, "test.tflite") + open(file, "wb").write(tflite_model) + yield file + + @pytest.fixture(scope="module") def compat_file(build_tosa_compat_model): tflite_model = create_tflite(build_tosa_compat_model) @@ -113,6 +145,30 @@ class TestTosaCompatibilityTool: ["tosa.reshape", True], ] + def test_tosa_non_compat_model_with_custom_op(self, non_compat_file_custom_op): + checker = tosa_checker.TOSAChecker(model_path=non_compat_file_custom_op) + tosa_compatible = checker.is_tosa_compatible() + assert tosa_compatible == False + + ops = checker._get_tosa_compatibility_for_ops() + assert type(ops) == list + assert [[op.name, op.is_tosa_compatible] for op in ops] == [ + ["tfl.custom", False], + ["tfl.pseudo_const", True], + ["tfl.no_value", True], + ["tfl.fully_connected", True], + ] + + tosa_ops = checker._get_used_tosa_ops() + assert type(tosa_ops) == list + assert [[op.name, op.is_tosa_compatible] for op in tosa_ops] == [ + ["tosa.const", True], + ["tosa.const", True], + ["tosa.custom", False], + ["tosa.fully_connected", True], + ["tosa.clamp", True], + ] + def test_tosa_compat_model(self, compat_file): checker = tosa_checker.TOSAChecker(model_path=compat_file) tosa_compatible = checker.is_tosa_compatible() -- cgit v1.2.1