diff options
author | Thibaut Goetghebuer-Planchon <thibaut.goetghebuer-planchon@arm.com> | 2022-11-23 11:42:43 +0000 |
---|---|---|
committer | Thibaut Goetghebuer-Planchon <thibaut.goetghebuer-planchon@arm.com> | 2022-12-13 10:53:38 +0000 |
commit | 9fb745d92173abfa270e99bd5c9bd7cf85bfeb31 (patch) | |
tree | 3a430220145a184a940131d6f7b0465dacdbde3b | |
parent | a2bcf5f818699082adfd346eba216d96f14d6e6c (diff) | |
download | tosa_checker-9fb745d92173abfa270e99bd5c9bd7cf85bfeb31.tar.gz |
Models with tfl.custom ops are considered to be non TOSA compliant
Change-Id: Ib31b44b3819bbdf517b96d834879155eebb4f09c
-rw-r--r-- | tests/test_tosa_checker.py | 56 | ||||
-rw-r--r-- | tosa_checker/tosa_checker.cc | 39 | ||||
-rw-r--r-- | tosa_checker/tosa_checker.h | 2 |
3 files changed, 86 insertions, 11 deletions
diff --git a/tests/test_tosa_checker.py b/tests/test_tosa_checker.py index eb49e65..b24073a 100644 --- a/tests/test_tosa_checker.py +++ b/tests/test_tosa_checker.py @@ -40,6 +40,29 @@ def build_tosa_non_compat_model(): @pytest.fixture(scope="module") +def build_tosa_non_compat_model_custom_op(): + @tf.function( + experimental_implements='name: "exp_log" \ + attr { \ + key: "tfl_fusable_op" \ + value { b: true } \ + }' + ) + def exp_log(x): + x = tf.math.exp(x) + x = tf.math.log(x) + + return x + + input = tf.keras.layers.Input(shape=(16,)) + x = tf.keras.layers.Lambda(exp_log)(input) + x = tf.keras.layers.Dense(8, activation="relu")(x) + model = tf.keras.models.Model(inputs=[input], outputs=x) + + return model + + +@pytest.fixture(scope="module") def build_tosa_compat_model(): input = tf.keras.layers.Input(shape=(16,)) x = tf.keras.layers.Dense(8, activation="relu")(input) @@ -63,6 +86,15 @@ def non_compat_file(build_tosa_non_compat_model): @pytest.fixture(scope="module") +def non_compat_file_custom_op(build_tosa_non_compat_model_custom_op): + tflite_model = create_tflite(build_tosa_non_compat_model_custom_op) + with tempfile.TemporaryDirectory() as tmp_dir: + file = os.path.join(tmp_dir, "test.tflite") + open(file, "wb").write(tflite_model) + yield file + + +@pytest.fixture(scope="module") def compat_file(build_tosa_compat_model): tflite_model = create_tflite(build_tosa_compat_model) with tempfile.TemporaryDirectory() as tmp_dir: @@ -113,6 +145,30 @@ class TestTosaCompatibilityTool: ["tosa.reshape", True], ] + def test_tosa_non_compat_model_with_custom_op(self, non_compat_file_custom_op): + checker = tosa_checker.TOSAChecker(model_path=non_compat_file_custom_op) + tosa_compatible = checker.is_tosa_compatible() + assert tosa_compatible == False + + ops = checker._get_tosa_compatibility_for_ops() + assert type(ops) == list + assert [[op.name, op.is_tosa_compatible] for op in ops] == [ + ["tfl.custom", False], + ["tfl.pseudo_const", True], + ["tfl.no_value", True], + ["tfl.fully_connected", True], + ] + + tosa_ops = checker._get_used_tosa_ops() + assert type(tosa_ops) == list + assert [[op.name, op.is_tosa_compatible] for op in tosa_ops] == [ + ["tosa.const", True], + ["tosa.const", True], + ["tosa.custom", False], + ["tosa.fully_connected", True], + ["tosa.clamp", True], + ] + def test_tosa_compat_model(self, compat_file): checker = tosa_checker.TOSAChecker(model_path=compat_file) tosa_compatible = checker.is_tosa_compatible() diff --git a/tosa_checker/tosa_checker.cc b/tosa_checker/tosa_checker.cc index 714cab3..d42b826 100644 --- a/tosa_checker/tosa_checker.cc +++ b/tosa_checker/tosa_checker.cc @@ -47,10 +47,10 @@ bool TOSAChecker::IsTOSACompatible() { bool is_tosa_compatible = true; for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) { func.walk([&](mlir::Operation *op) { - // Ignore func namespace - const mlir::Dialect *dialect = op->getDialect(); - if (!dialect || (!dialect->getNamespace().equals("tosa") && - !dialect->getNamespace().equals("func"))) { + // Ignore func dialect ops + const bool is_func = + op->getDialect() && op->getDialect()->getNamespace().equals("func"); + if (!is_func && !IsTOSACompatibleOp(*op)) { is_tosa_compatible = false; return mlir::WalkResult::interrupt(); } @@ -69,10 +69,10 @@ std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps( std::unordered_set<mlir::Location> tosa_incompatible_locs; for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) { func.walk([&](mlir::Operation *op) { - // Ignore func namespace - const mlir::Dialect *dialect = op->getDialect(); - if (!dialect || (!dialect->getNamespace().equals("tosa") && - !dialect->getNamespace().equals("func"))) { + // Ignore func dialect ops + const bool is_func = + op->getDialect() && op->getDialect()->getNamespace().equals("func"); + if (!is_func && !IsTOSACompatibleOp(*op)) { tosa_incompatible_locs.insert(op->getLoc()); } }); @@ -85,8 +85,9 @@ std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps( for (auto func : m_model->getOps<mlir::func::FuncOp>()) { func.walk([&](mlir::Operation *op) { // Ignore func namespace - const mlir::Dialect *dialect = op->getDialect(); - if (!dialect || !dialect->getNamespace().equals("func")) { + const bool is_func = + op->getDialect() && op->getDialect()->getNamespace().equals("func"); + if (!is_func) { const bool is_tosa_compatible = tosa_incompatible_locs.find(op->getLoc()) == tosa_incompatible_locs.end(); @@ -102,7 +103,7 @@ std::vector<TOSAChecker::Operator> TOSAChecker::GetUsedTOSAOps( bool elide_large_attrs) { std::vector<Operator> tosa_ops; for (mlir::Operation *op : GetTOSAOps(*m_tosa_model)) { - const bool is_tosa_compatible = true; + const bool is_tosa_compatible = IsTOSACompatibleOp(*op); tosa_ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs)); } @@ -118,6 +119,22 @@ std::string TOSAChecker::GetMLIRTOSAModelRepresentation( return GetMLIRRepresentation(*m_tosa_model, elide_large_attrs); } +bool TOSAChecker::IsTOSACompatibleOp(mlir::Operation &op) { + const mlir::Dialect *dialect = op.getDialect(); + if (dialect && dialect->getNamespace().equals("tosa")) { + // Due to the opaque nature of the tosa.custom operator, a TOSA compliant + // system may not be able to run a model with such operators. We + // consider these models as TOSA incompatible. + if (op.getName().getStringRef().equals("tosa.custom")) { + return false; + } + + return true; + } + + return false; +} + template <typename T> std::string TOSAChecker::GetMLIRRepresentation(T &&op) { std::string value; diff --git a/tosa_checker/tosa_checker.h b/tosa_checker/tosa_checker.h index d7750ea..f0e473a 100644 --- a/tosa_checker/tosa_checker.h +++ b/tosa_checker/tosa_checker.h @@ -47,6 +47,8 @@ class TOSAChecker { std::string GetMLIRTOSAModelRepresentation(bool elide_large_attrs); private: + static bool IsTOSACompatibleOp(mlir::Operation& op); + template <typename T> static std::string GetMLIRRepresentation(T&& op); |