aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThibaut Goetghebuer-Planchon <thibaut.goetghebuer-planchon@arm.com>2022-11-23 11:42:43 +0000
committerThibaut Goetghebuer-Planchon <thibaut.goetghebuer-planchon@arm.com>2022-12-13 10:53:38 +0000
commit9fb745d92173abfa270e99bd5c9bd7cf85bfeb31 (patch)
tree3a430220145a184a940131d6f7b0465dacdbde3b
parenta2bcf5f818699082adfd346eba216d96f14d6e6c (diff)
downloadtosa_checker-9fb745d92173abfa270e99bd5c9bd7cf85bfeb31.tar.gz
Models with tfl.custom ops are considered to be non TOSA compliant
Change-Id: Ib31b44b3819bbdf517b96d834879155eebb4f09c
-rw-r--r--tests/test_tosa_checker.py56
-rw-r--r--tosa_checker/tosa_checker.cc39
-rw-r--r--tosa_checker/tosa_checker.h2
3 files changed, 86 insertions, 11 deletions
diff --git a/tests/test_tosa_checker.py b/tests/test_tosa_checker.py
index eb49e65..b24073a 100644
--- a/tests/test_tosa_checker.py
+++ b/tests/test_tosa_checker.py
@@ -40,6 +40,29 @@ def build_tosa_non_compat_model():
@pytest.fixture(scope="module")
+def build_tosa_non_compat_model_custom_op():
+ @tf.function(
+ experimental_implements='name: "exp_log" \
+ attr { \
+ key: "tfl_fusable_op" \
+ value { b: true } \
+ }'
+ )
+ def exp_log(x):
+ x = tf.math.exp(x)
+ x = tf.math.log(x)
+
+ return x
+
+ input = tf.keras.layers.Input(shape=(16,))
+ x = tf.keras.layers.Lambda(exp_log)(input)
+ x = tf.keras.layers.Dense(8, activation="relu")(x)
+ model = tf.keras.models.Model(inputs=[input], outputs=x)
+
+ return model
+
+
+@pytest.fixture(scope="module")
def build_tosa_compat_model():
input = tf.keras.layers.Input(shape=(16,))
x = tf.keras.layers.Dense(8, activation="relu")(input)
@@ -63,6 +86,15 @@ def non_compat_file(build_tosa_non_compat_model):
@pytest.fixture(scope="module")
+def non_compat_file_custom_op(build_tosa_non_compat_model_custom_op):
+ tflite_model = create_tflite(build_tosa_non_compat_model_custom_op)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ file = os.path.join(tmp_dir, "test.tflite")
+ open(file, "wb").write(tflite_model)
+ yield file
+
+
+@pytest.fixture(scope="module")
def compat_file(build_tosa_compat_model):
tflite_model = create_tflite(build_tosa_compat_model)
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -113,6 +145,30 @@ class TestTosaCompatibilityTool:
["tosa.reshape", True],
]
+ def test_tosa_non_compat_model_with_custom_op(self, non_compat_file_custom_op):
+ checker = tosa_checker.TOSAChecker(model_path=non_compat_file_custom_op)
+ tosa_compatible = checker.is_tosa_compatible()
+ assert tosa_compatible == False
+
+ ops = checker._get_tosa_compatibility_for_ops()
+ assert type(ops) == list
+ assert [[op.name, op.is_tosa_compatible] for op in ops] == [
+ ["tfl.custom", False],
+ ["tfl.pseudo_const", True],
+ ["tfl.no_value", True],
+ ["tfl.fully_connected", True],
+ ]
+
+ tosa_ops = checker._get_used_tosa_ops()
+ assert type(tosa_ops) == list
+ assert [[op.name, op.is_tosa_compatible] for op in tosa_ops] == [
+ ["tosa.const", True],
+ ["tosa.const", True],
+ ["tosa.custom", False],
+ ["tosa.fully_connected", True],
+ ["tosa.clamp", True],
+ ]
+
def test_tosa_compat_model(self, compat_file):
checker = tosa_checker.TOSAChecker(model_path=compat_file)
tosa_compatible = checker.is_tosa_compatible()
diff --git a/tosa_checker/tosa_checker.cc b/tosa_checker/tosa_checker.cc
index 714cab3..d42b826 100644
--- a/tosa_checker/tosa_checker.cc
+++ b/tosa_checker/tosa_checker.cc
@@ -47,10 +47,10 @@ bool TOSAChecker::IsTOSACompatible() {
bool is_tosa_compatible = true;
for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) {
func.walk([&](mlir::Operation *op) {
- // Ignore func namespace
- const mlir::Dialect *dialect = op->getDialect();
- if (!dialect || (!dialect->getNamespace().equals("tosa") &&
- !dialect->getNamespace().equals("func"))) {
+ // Ignore func dialect ops
+ const bool is_func =
+ op->getDialect() && op->getDialect()->getNamespace().equals("func");
+ if (!is_func && !IsTOSACompatibleOp(*op)) {
is_tosa_compatible = false;
return mlir::WalkResult::interrupt();
}
@@ -69,10 +69,10 @@ std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps(
std::unordered_set<mlir::Location> tosa_incompatible_locs;
for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) {
func.walk([&](mlir::Operation *op) {
- // Ignore func namespace
- const mlir::Dialect *dialect = op->getDialect();
- if (!dialect || (!dialect->getNamespace().equals("tosa") &&
- !dialect->getNamespace().equals("func"))) {
+ // Ignore func dialect ops
+ const bool is_func =
+ op->getDialect() && op->getDialect()->getNamespace().equals("func");
+ if (!is_func && !IsTOSACompatibleOp(*op)) {
tosa_incompatible_locs.insert(op->getLoc());
}
});
@@ -85,8 +85,9 @@ std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps(
for (auto func : m_model->getOps<mlir::func::FuncOp>()) {
func.walk([&](mlir::Operation *op) {
// Ignore func namespace
- const mlir::Dialect *dialect = op->getDialect();
- if (!dialect || !dialect->getNamespace().equals("func")) {
+ const bool is_func =
+ op->getDialect() && op->getDialect()->getNamespace().equals("func");
+ if (!is_func) {
const bool is_tosa_compatible =
tosa_incompatible_locs.find(op->getLoc()) ==
tosa_incompatible_locs.end();
@@ -102,7 +103,7 @@ std::vector<TOSAChecker::Operator> TOSAChecker::GetUsedTOSAOps(
bool elide_large_attrs) {
std::vector<Operator> tosa_ops;
for (mlir::Operation *op : GetTOSAOps(*m_tosa_model)) {
- const bool is_tosa_compatible = true;
+ const bool is_tosa_compatible = IsTOSACompatibleOp(*op);
tosa_ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs));
}
@@ -118,6 +119,22 @@ std::string TOSAChecker::GetMLIRTOSAModelRepresentation(
return GetMLIRRepresentation(*m_tosa_model, elide_large_attrs);
}
+bool TOSAChecker::IsTOSACompatibleOp(mlir::Operation &op) {
+ const mlir::Dialect *dialect = op.getDialect();
+ if (dialect && dialect->getNamespace().equals("tosa")) {
+ // Due to the opaque nature of the tosa.custom operator, a TOSA compliant
+ // system may not be able to run a model with such operators. We
+ // consider these models as TOSA incompatible.
+ if (op.getName().getStringRef().equals("tosa.custom")) {
+ return false;
+ }
+
+ return true;
+ }
+
+ return false;
+}
+
template <typename T>
std::string TOSAChecker::GetMLIRRepresentation(T &&op) {
std::string value;
diff --git a/tosa_checker/tosa_checker.h b/tosa_checker/tosa_checker.h
index d7750ea..f0e473a 100644
--- a/tosa_checker/tosa_checker.h
+++ b/tosa_checker/tosa_checker.h
@@ -47,6 +47,8 @@ class TOSAChecker {
std::string GetMLIRTOSAModelRepresentation(bool elide_large_attrs);
private:
+ static bool IsTOSACompatibleOp(mlir::Operation& op);
+
template <typename T>
static std::string GetMLIRRepresentation(T&& op);