diff options
Diffstat (limited to 'tosa_checker/tosa_checker.cc')
-rw-r--r-- | tosa_checker/tosa_checker.cc | 39 |
1 files changed, 28 insertions, 11 deletions
diff --git a/tosa_checker/tosa_checker.cc b/tosa_checker/tosa_checker.cc index 714cab3..d42b826 100644 --- a/tosa_checker/tosa_checker.cc +++ b/tosa_checker/tosa_checker.cc @@ -47,10 +47,10 @@ bool TOSAChecker::IsTOSACompatible() { bool is_tosa_compatible = true; for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) { func.walk([&](mlir::Operation *op) { - // Ignore func namespace - const mlir::Dialect *dialect = op->getDialect(); - if (!dialect || (!dialect->getNamespace().equals("tosa") && - !dialect->getNamespace().equals("func"))) { + // Ignore func dialect ops + const bool is_func = + op->getDialect() && op->getDialect()->getNamespace().equals("func"); + if (!is_func && !IsTOSACompatibleOp(*op)) { is_tosa_compatible = false; return mlir::WalkResult::interrupt(); } @@ -69,10 +69,10 @@ std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps( std::unordered_set<mlir::Location> tosa_incompatible_locs; for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) { func.walk([&](mlir::Operation *op) { - // Ignore func namespace - const mlir::Dialect *dialect = op->getDialect(); - if (!dialect || (!dialect->getNamespace().equals("tosa") && - !dialect->getNamespace().equals("func"))) { + // Ignore func dialect ops + const bool is_func = + op->getDialect() && op->getDialect()->getNamespace().equals("func"); + if (!is_func && !IsTOSACompatibleOp(*op)) { tosa_incompatible_locs.insert(op->getLoc()); } }); @@ -85,8 +85,9 @@ std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps( for (auto func : m_model->getOps<mlir::func::FuncOp>()) { func.walk([&](mlir::Operation *op) { // Ignore func namespace - const mlir::Dialect *dialect = op->getDialect(); - if (!dialect || !dialect->getNamespace().equals("func")) { + const bool is_func = + op->getDialect() && op->getDialect()->getNamespace().equals("func"); + if (!is_func) { const bool is_tosa_compatible = tosa_incompatible_locs.find(op->getLoc()) == tosa_incompatible_locs.end(); @@ -102,7 +103,7 @@ std::vector<TOSAChecker::Operator> TOSAChecker::GetUsedTOSAOps( bool elide_large_attrs) { std::vector<Operator> tosa_ops; for (mlir::Operation *op : GetTOSAOps(*m_tosa_model)) { - const bool is_tosa_compatible = true; + const bool is_tosa_compatible = IsTOSACompatibleOp(*op); tosa_ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs)); } @@ -118,6 +119,22 @@ std::string TOSAChecker::GetMLIRTOSAModelRepresentation( return GetMLIRRepresentation(*m_tosa_model, elide_large_attrs); } +bool TOSAChecker::IsTOSACompatibleOp(mlir::Operation &op) { + const mlir::Dialect *dialect = op.getDialect(); + if (dialect && dialect->getNamespace().equals("tosa")) { + // Due to the opaque nature of the tosa.custom operator, a TOSA compliant + // system may not be able to run a model with such operators. We + // consider these models as TOSA incompatible. + if (op.getName().getStringRef().equals("tosa.custom")) { + return false; + } + + return true; + } + + return false; +} + template <typename T> std::string TOSAChecker::GetMLIRRepresentation(T &&op) { std::string value; |