From 9fb745d92173abfa270e99bd5c9bd7cf85bfeb31 Mon Sep 17 00:00:00 2001 From: Thibaut Goetghebuer-Planchon Date: Wed, 23 Nov 2022 11:42:43 +0000 Subject: Models with tfl.custom ops are considered to be non TOSA compliant Change-Id: Ib31b44b3819bbdf517b96d834879155eebb4f09c --- tests/test_tosa_checker.py | 56 ++++++++++++++++++++++++++++++++++++++++++++ tosa_checker/tosa_checker.cc | 39 +++++++++++++++++++++--------- tosa_checker/tosa_checker.h | 2 ++ 3 files changed, 86 insertions(+), 11 deletions(-) diff --git a/tests/test_tosa_checker.py b/tests/test_tosa_checker.py index eb49e65..b24073a 100644 --- a/tests/test_tosa_checker.py +++ b/tests/test_tosa_checker.py @@ -39,6 +39,29 @@ def build_tosa_non_compat_model(): return model +@pytest.fixture(scope="module") +def build_tosa_non_compat_model_custom_op(): + @tf.function( + experimental_implements='name: "exp_log" \ + attr { \ + key: "tfl_fusable_op" \ + value { b: true } \ + }' + ) + def exp_log(x): + x = tf.math.exp(x) + x = tf.math.log(x) + + return x + + input = tf.keras.layers.Input(shape=(16,)) + x = tf.keras.layers.Lambda(exp_log)(input) + x = tf.keras.layers.Dense(8, activation="relu")(x) + model = tf.keras.models.Model(inputs=[input], outputs=x) + + return model + + @pytest.fixture(scope="module") def build_tosa_compat_model(): input = tf.keras.layers.Input(shape=(16,)) @@ -62,6 +85,15 @@ def non_compat_file(build_tosa_non_compat_model): yield file +@pytest.fixture(scope="module") +def non_compat_file_custom_op(build_tosa_non_compat_model_custom_op): + tflite_model = create_tflite(build_tosa_non_compat_model_custom_op) + with tempfile.TemporaryDirectory() as tmp_dir: + file = os.path.join(tmp_dir, "test.tflite") + open(file, "wb").write(tflite_model) + yield file + + @pytest.fixture(scope="module") def compat_file(build_tosa_compat_model): tflite_model = create_tflite(build_tosa_compat_model) @@ -113,6 +145,30 @@ class TestTosaCompatibilityTool: ["tosa.reshape", True], ] + def test_tosa_non_compat_model_with_custom_op(self, non_compat_file_custom_op): + checker = tosa_checker.TOSAChecker(model_path=non_compat_file_custom_op) + tosa_compatible = checker.is_tosa_compatible() + assert tosa_compatible == False + + ops = checker._get_tosa_compatibility_for_ops() + assert type(ops) == list + assert [[op.name, op.is_tosa_compatible] for op in ops] == [ + ["tfl.custom", False], + ["tfl.pseudo_const", True], + ["tfl.no_value", True], + ["tfl.fully_connected", True], + ] + + tosa_ops = checker._get_used_tosa_ops() + assert type(tosa_ops) == list + assert [[op.name, op.is_tosa_compatible] for op in tosa_ops] == [ + ["tosa.const", True], + ["tosa.const", True], + ["tosa.custom", False], + ["tosa.fully_connected", True], + ["tosa.clamp", True], + ] + def test_tosa_compat_model(self, compat_file): checker = tosa_checker.TOSAChecker(model_path=compat_file) tosa_compatible = checker.is_tosa_compatible() diff --git a/tosa_checker/tosa_checker.cc b/tosa_checker/tosa_checker.cc index 714cab3..d42b826 100644 --- a/tosa_checker/tosa_checker.cc +++ b/tosa_checker/tosa_checker.cc @@ -47,10 +47,10 @@ bool TOSAChecker::IsTOSACompatible() { bool is_tosa_compatible = true; for (auto func : m_tosa_model->getOps()) { func.walk([&](mlir::Operation *op) { - // Ignore func namespace - const mlir::Dialect *dialect = op->getDialect(); - if (!dialect || (!dialect->getNamespace().equals("tosa") && - !dialect->getNamespace().equals("func"))) { + // Ignore func dialect ops + const bool is_func = + op->getDialect() && op->getDialect()->getNamespace().equals("func"); + if (!is_func && !IsTOSACompatibleOp(*op)) { is_tosa_compatible = false; return mlir::WalkResult::interrupt(); } @@ -69,10 +69,10 @@ std::vector TOSAChecker::GetTOSACompatibilityForOps( std::unordered_set tosa_incompatible_locs; for (auto func : m_tosa_model->getOps()) { func.walk([&](mlir::Operation *op) { - // Ignore func namespace - const mlir::Dialect *dialect = op->getDialect(); - if (!dialect || (!dialect->getNamespace().equals("tosa") && - !dialect->getNamespace().equals("func"))) { + // Ignore func dialect ops + const bool is_func = + op->getDialect() && op->getDialect()->getNamespace().equals("func"); + if (!is_func && !IsTOSACompatibleOp(*op)) { tosa_incompatible_locs.insert(op->getLoc()); } }); @@ -85,8 +85,9 @@ std::vector TOSAChecker::GetTOSACompatibilityForOps( for (auto func : m_model->getOps()) { func.walk([&](mlir::Operation *op) { // Ignore func namespace - const mlir::Dialect *dialect = op->getDialect(); - if (!dialect || !dialect->getNamespace().equals("func")) { + const bool is_func = + op->getDialect() && op->getDialect()->getNamespace().equals("func"); + if (!is_func) { const bool is_tosa_compatible = tosa_incompatible_locs.find(op->getLoc()) == tosa_incompatible_locs.end(); @@ -102,7 +103,7 @@ std::vector TOSAChecker::GetUsedTOSAOps( bool elide_large_attrs) { std::vector tosa_ops; for (mlir::Operation *op : GetTOSAOps(*m_tosa_model)) { - const bool is_tosa_compatible = true; + const bool is_tosa_compatible = IsTOSACompatibleOp(*op); tosa_ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs)); } @@ -118,6 +119,22 @@ std::string TOSAChecker::GetMLIRTOSAModelRepresentation( return GetMLIRRepresentation(*m_tosa_model, elide_large_attrs); } +bool TOSAChecker::IsTOSACompatibleOp(mlir::Operation &op) { + const mlir::Dialect *dialect = op.getDialect(); + if (dialect && dialect->getNamespace().equals("tosa")) { + // Due to the opaque nature of the tosa.custom operator, a TOSA compliant + // system may not be able to run a model with such operators. We + // consider these models as TOSA incompatible. + if (op.getName().getStringRef().equals("tosa.custom")) { + return false; + } + + return true; + } + + return false; +} + template std::string TOSAChecker::GetMLIRRepresentation(T &&op) { std::string value; diff --git a/tosa_checker/tosa_checker.h b/tosa_checker/tosa_checker.h index d7750ea..f0e473a 100644 --- a/tosa_checker/tosa_checker.h +++ b/tosa_checker/tosa_checker.h @@ -47,6 +47,8 @@ class TOSAChecker { std::string GetMLIRTOSAModelRepresentation(bool elide_large_attrs); private: + static bool IsTOSACompatibleOp(mlir::Operation& op); + template static std::string GetMLIRRepresentation(T&& op); -- cgit v1.2.1