# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for ``tosa_checker.TOSAChecker``.

Small Keras models are converted to TFLite flatbuffers and written to
temporary files. Each file is then loaded by the checker, and the tests compare
the reported per-op TOSA compatibility and the legalized TOSA ops against known
listings.
"""
import os
import tempfile

import pytest
import tensorflow as tf

import tosa_checker


@pytest.fixture(scope="module")
def build_tosa_non_compat_model():
    """Build a Keras model wrapping ``tf.image.non_max_suppression_with_scores``.

    The converted model contains ``tfl.non_max_suppression_v5``, which the
    tests below expect the checker to report as not TOSA-compatible.
    """
    num_boxes = 6
    max_output_size = 5
    iou_threshold = 0.5
    score_threshold = 0.1

    def non_max_suppression(x):
        boxes, scores = x[0], x[1]
        # Index 0 drops the batch dimension; both inputs are declared with
        # batch_size=1 below.
        return tf.image.non_max_suppression_with_scores(
            boxes[0],
            scores[0],
            max_output_size=max_output_size,
            iou_threshold=iou_threshold,
            score_threshold=score_threshold,
            soft_nms_sigma=1.0,
        )

    boxes_in = tf.keras.layers.Input(
        shape=(num_boxes, 4), batch_size=1, dtype=tf.float32, name="boxes"
    )
    # A 1-tuple: per-box scores, one value per box.
    scores_in = tf.keras.layers.Input(
        shape=(num_boxes,), batch_size=1, dtype=tf.float32, name="scores"
    )
    outputs = tf.keras.layers.Lambda(non_max_suppression, name="nms")(
        [boxes_in, scores_in]
    )
    return tf.keras.models.Model(inputs=[boxes_in, scores_in], outputs=outputs)


@pytest.fixture(scope="module")
def build_tosa_non_compat_model_custom_op():
    """Build a model whose first op is a fusable ``exp_log`` composite function.

    The ``experimental_implements`` text proto marks the function with the
    ``tfl_fusable_op`` attribute. The tests expect it to appear as
    ``tfl.custom`` (not TOSA-compatible) followed by a compatible dense layer.
    """

    @tf.function(
        experimental_implements=(
            'name: "exp_log" '
            'attr { key: "tfl_fusable_op" value { b: true } }'
        )
    )
    def exp_log(x):
        x = tf.math.exp(x)
        x = tf.math.log(x)
        return x

    inputs = tf.keras.layers.Input(shape=(16,), name="input")
    x = tf.keras.layers.Lambda(exp_log, name="exp_log")(inputs)
    x = tf.keras.layers.Dense(8, activation="relu", name="dense")(x)
    return tf.keras.models.Model(inputs=[inputs], outputs=x)


@pytest.fixture(scope="module")
def build_tosa_compat_model():
    """Build a single-Dense-layer model that is expected to be TOSA-compatible."""
    inputs = tf.keras.layers.Input(shape=(16,), name="input")
    x = tf.keras.layers.Dense(8, activation="relu", name="dense")(inputs)
    return tf.keras.models.Model(inputs=[inputs], outputs=x)


def create_tflite(model):
    """Convert a Keras *model* with ``TFLiteConverter`` and return the result.

    The returned value is written to disk in binary mode by the fixtures.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    return converter.convert()


def _tflite_file(model):
    """Write *model*, converted to TFLite, to a temporary file and yield its path.

    Meant to be delegated to from a fixture via ``yield from``. The temporary
    directory and the file in it are removed when the generator is resumed
    after the yield, i.e. at fixture teardown.
    """
    tflite_model = create_tflite(model)
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "test.tflite")
        # Close the handle before handing the path to the checker.
        with open(path, "wb") as f:
            f.write(tflite_model)
        yield path


@pytest.fixture(scope="module")
def non_compat_file(build_tosa_non_compat_model):
    """Path to a temporary TFLite file of the non-max-suppression model."""
    yield from _tflite_file(build_tosa_non_compat_model)


@pytest.fixture(scope="module")
def non_compat_file_custom_op(build_tosa_non_compat_model_custom_op):
    """Path to a temporary TFLite file of the ``exp_log`` custom-op model."""
    yield from _tflite_file(build_tosa_non_compat_model_custom_op)


@pytest.fixture(scope="module")
def compat_file(build_tosa_compat_model):
    """Path to a temporary TFLite file of the single-Dense-layer model."""
    yield from _tflite_file(build_tosa_compat_model)


def _op_summary(ops):
    """Return ``[name, is_tosa_compatible]`` pairs for *ops*, preserving order."""
    return [[op.name, op.is_tosa_compatible] for op in ops]


class TestTosaCompatibilityTool:
    def test_bad_tflite_file(self):
        """Loading a file that is not a valid TFLite model raises RuntimeError."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            bad_file = os.path.join(tmp_dir, "test.tflite")
            with open(bad_file, "wb") as f:
                f.write("bad tflite file".encode("ASCII"))
            with pytest.raises(RuntimeError):
                tosa_checker.TOSAChecker(model_path=bad_file)

    def test_tosa_non_compat_model(self, non_compat_file):
        """The NMS model is reported incompatible because of its NMS op."""
        checker = tosa_checker.TOSAChecker(model_path=non_compat_file)
        assert not checker.is_tosa_compatible()

        ops = checker._get_tosa_compatibility_for_ops()
        assert isinstance(ops, list)
        assert _op_summary(ops) == [
            ["tfl.pseudo_const", True],
            ["tfl.pseudo_const", True],
            ["tfl.pseudo_const", True],
            ["tfl.strided_slice", True],
            ["tfl.pseudo_const", True],
            ["tfl.pseudo_const", True],
            ["tfl.pseudo_const", True],
            ["tfl.strided_slice", True],
            ["tfl.pseudo_const", True],
            ["tfl.pseudo_const", True],
            ["tfl.pseudo_const", True],
            ["tfl.pseudo_const", True],
            ["tfl.non_max_suppression_v5", False],
        ]

        tosa_ops = checker._get_used_tosa_ops()
        assert isinstance(tosa_ops, list)
        assert _op_summary(tosa_ops) == [
            ["tosa.const", True],
            ["tosa.const", True],
            ["tosa.const", True],
            ["tosa.const", True],
            ["tosa.reshape", True],
            ["tosa.reshape", True],
        ]

    def test_tosa_non_compat_model_with_custom_op(self, non_compat_file_custom_op):
        """The ``exp_log`` model is reported incompatible because of ``tfl.custom``."""
        checker = tosa_checker.TOSAChecker(model_path=non_compat_file_custom_op)
        assert not checker.is_tosa_compatible()

        ops = checker._get_tosa_compatibility_for_ops()
        assert isinstance(ops, list)
        assert _op_summary(ops) == [
            ["tfl.custom", False],
            ["tfl.pseudo_const", True],
            ["tfl.no_value", True],
            ["tfl.fully_connected", True],
        ]

        tosa_ops = checker._get_used_tosa_ops()
        assert isinstance(tosa_ops, list)
        assert _op_summary(tosa_ops) == [
            ["tosa.const", True],
            ["tosa.const", True],
            ["tosa.custom", False],
            ["tosa.fully_connected", True],
            ["tosa.clamp", True],
        ]

    def test_tosa_compat_model(self, compat_file):
        """The single-Dense-layer model is reported fully TOSA-compatible."""
        checker = tosa_checker.TOSAChecker(model_path=compat_file)
        assert checker.is_tosa_compatible()

        ops = checker._get_tosa_compatibility_for_ops()
        assert isinstance(ops, list)
        assert _op_summary(ops) == [
            ["tfl.pseudo_const", True],
            ["tfl.no_value", True],
            ["tfl.fully_connected", True],
        ]

        tosa_ops = checker._get_used_tosa_ops()
        assert isinstance(tosa_ops, list)
        assert _op_summary(tosa_ops) == [
            ["tosa.const", True],
            ["tosa.const", True],
            ["tosa.fully_connected", True],
            ["tosa.clamp", True],
        ]

    def test_tosa_non_compat_model_mlir_representation(self, non_compat_file):
        """The NMS op appears in both the TFL and the TOSA MLIR dumps."""
        checker = tosa_checker.TOSAChecker(model_path=non_compat_file)

        tfl_mlir_representation = checker._get_mlir_model_representation(
            elide_large_elements_attrs=True
        )
        assert "non_max_suppression_v5" in tfl_mlir_representation

        # The op is asserted to remain in the TOSA dump.
        tosa_mlir_representation = checker._get_mlir_tosa_model_representation(
            elide_large_elements_attrs=True
        )
        assert "non_max_suppression_v5" in tosa_mlir_representation

    def test_tosa_compat_model_mlir_representation(self, compat_file):
        """``fully_connected`` appears in both the TFL and the TOSA MLIR dumps."""
        checker = tosa_checker.TOSAChecker(model_path=compat_file)

        tfl_mlir_representation = checker._get_mlir_model_representation(
            elide_large_elements_attrs=True
        )
        assert "fully_connected" in tfl_mlir_representation

        tosa_mlir_representation = checker._get_mlir_tosa_model_representation(
            elide_large_elements_attrs=True
        )
        assert "fully_connected" in tosa_mlir_representation