aboutsummaryrefslogtreecommitdiff
path: root/tests/test_tosa_checker.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_tosa_checker.py')
-rw-r--r--tests/test_tosa_checker.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/tests/test_tosa_checker.py b/tests/test_tosa_checker.py
index eb49e65..b24073a 100644
--- a/tests/test_tosa_checker.py
+++ b/tests/test_tosa_checker.py
@@ -40,6 +40,29 @@ def build_tosa_non_compat_model():
@pytest.fixture(scope="module")
+def build_tosa_non_compat_model_custom_op():
+ @tf.function(
+ experimental_implements='name: "exp_log" \
+ attr { \
+ key: "tfl_fusable_op" \
+ value { b: true } \
+ }'
+ )
+ def exp_log(x):
+ x = tf.math.exp(x)
+ x = tf.math.log(x)
+
+ return x
+
+ input = tf.keras.layers.Input(shape=(16,))
+ x = tf.keras.layers.Lambda(exp_log)(input)
+ x = tf.keras.layers.Dense(8, activation="relu")(x)
+ model = tf.keras.models.Model(inputs=[input], outputs=x)
+
+ return model
+
+
+@pytest.fixture(scope="module")
def build_tosa_compat_model():
input = tf.keras.layers.Input(shape=(16,))
x = tf.keras.layers.Dense(8, activation="relu")(input)
@@ -63,6 +86,15 @@ def non_compat_file(build_tosa_non_compat_model):
@pytest.fixture(scope="module")
+def non_compat_file_custom_op(build_tosa_non_compat_model_custom_op):
+ tflite_model = create_tflite(build_tosa_non_compat_model_custom_op)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ file = os.path.join(tmp_dir, "test.tflite")
+ open(file, "wb").write(tflite_model)
+ yield file
+
+
+@pytest.fixture(scope="module")
def compat_file(build_tosa_compat_model):
tflite_model = create_tflite(build_tosa_compat_model)
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -113,6 +145,30 @@ class TestTosaCompatibilityTool:
["tosa.reshape", True],
]
+ def test_tosa_non_compat_model_with_custom_op(self, non_compat_file_custom_op):
+ checker = tosa_checker.TOSAChecker(model_path=non_compat_file_custom_op)
+ tosa_compatible = checker.is_tosa_compatible()
+ assert tosa_compatible == False
+
+ ops = checker._get_tosa_compatibility_for_ops()
+ assert type(ops) == list
+ assert [[op.name, op.is_tosa_compatible] for op in ops] == [
+ ["tfl.custom", False],
+ ["tfl.pseudo_const", True],
+ ["tfl.no_value", True],
+ ["tfl.fully_connected", True],
+ ]
+
+ tosa_ops = checker._get_used_tosa_ops()
+ assert type(tosa_ops) == list
+ assert [[op.name, op.is_tosa_compatible] for op in tosa_ops] == [
+ ["tosa.const", True],
+ ["tosa.const", True],
+ ["tosa.custom", False],
+ ["tosa.fully_connected", True],
+ ["tosa.clamp", True],
+ ]
+
def test_tosa_compat_model(self, compat_file):
checker = tosa_checker.TOSAChecker(model_path=compat_file)
tosa_compatible = checker.is_tosa_compatible()