aboutsummaryrefslogtreecommitdiff
path: root/tosa_checker/tosa_checker.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tosa_checker/tosa_checker.cc')
-rw-r--r--tosa_checker/tosa_checker.cc39
1 files changed, 28 insertions, 11 deletions
diff --git a/tosa_checker/tosa_checker.cc b/tosa_checker/tosa_checker.cc
index 714cab3..d42b826 100644
--- a/tosa_checker/tosa_checker.cc
+++ b/tosa_checker/tosa_checker.cc
@@ -47,10 +47,10 @@ bool TOSAChecker::IsTOSACompatible() {
bool is_tosa_compatible = true;
for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) {
func.walk([&](mlir::Operation *op) {
- // Ignore func namespace
- const mlir::Dialect *dialect = op->getDialect();
- if (!dialect || (!dialect->getNamespace().equals("tosa") &&
- !dialect->getNamespace().equals("func"))) {
+ // Ignore func dialect ops
+ const bool is_func =
+ op->getDialect() && op->getDialect()->getNamespace().equals("func");
+ if (!is_func && !IsTOSACompatibleOp(*op)) {
is_tosa_compatible = false;
return mlir::WalkResult::interrupt();
}
@@ -69,10 +69,10 @@ std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps(
std::unordered_set<mlir::Location> tosa_incompatible_locs;
for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) {
func.walk([&](mlir::Operation *op) {
- // Ignore func namespace
- const mlir::Dialect *dialect = op->getDialect();
- if (!dialect || (!dialect->getNamespace().equals("tosa") &&
- !dialect->getNamespace().equals("func"))) {
+ // Ignore func dialect ops
+ const bool is_func =
+ op->getDialect() && op->getDialect()->getNamespace().equals("func");
+ if (!is_func && !IsTOSACompatibleOp(*op)) {
tosa_incompatible_locs.insert(op->getLoc());
}
});
@@ -85,8 +85,9 @@ std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps(
for (auto func : m_model->getOps<mlir::func::FuncOp>()) {
func.walk([&](mlir::Operation *op) {
// Ignore func namespace
- const mlir::Dialect *dialect = op->getDialect();
- if (!dialect || !dialect->getNamespace().equals("func")) {
+ const bool is_func =
+ op->getDialect() && op->getDialect()->getNamespace().equals("func");
+ if (!is_func) {
const bool is_tosa_compatible =
tosa_incompatible_locs.find(op->getLoc()) ==
tosa_incompatible_locs.end();
@@ -102,7 +103,7 @@ std::vector<TOSAChecker::Operator> TOSAChecker::GetUsedTOSAOps(
bool elide_large_attrs) {
std::vector<Operator> tosa_ops;
for (mlir::Operation *op : GetTOSAOps(*m_tosa_model)) {
- const bool is_tosa_compatible = true;
+ const bool is_tosa_compatible = IsTOSACompatibleOp(*op);
tosa_ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs));
}
@@ -118,6 +119,22 @@ std::string TOSAChecker::GetMLIRTOSAModelRepresentation(
return GetMLIRRepresentation(*m_tosa_model, elide_large_attrs);
}
+bool TOSAChecker::IsTOSACompatibleOp(mlir::Operation &op) {
+ const mlir::Dialect *dialect = op.getDialect();
+ if (dialect && dialect->getNamespace().equals("tosa")) {
+ // Due to the opaque nature of the tosa.custom operator, a TOSA compliant
+ // system may not be able to run a model with such operators. We
+ // consider these models as TOSA incompatible.
+ if (op.getName().getStringRef().equals("tosa.custom")) {
+ return false;
+ }
+
+ return true;
+ }
+
+ return false;
+}
+
template <typename T>
std::string TOSAChecker::GetMLIRRepresentation(T &&op) {
std::string value;