aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/control_flow.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/control_flow.cc')
-rw-r--r--reference_model/src/ops/control_flow.cc39
1 files changed, 31 insertions, 8 deletions
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index 942652d..f573d5b 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -181,10 +181,22 @@ int OpCondIf::checkTensorAttributes()
cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
ASSERT_MEM(cond);
- auto region_name = getParentSGT()->getRegionName();
- auto curr_region = tsh->GetRegionByName(region_name);
- then_block = curr_region->GetBlockByName(attribute->then_branch());
- else_block = curr_region->GetBlockByName(attribute->else_branch());
+ auto then_region = tsh->GetRegionByName(attribute->then_branch());
+ auto else_region = tsh->GetRegionByName(attribute->else_branch());
+ if (then_region && else_region)
+ {
+ // new serialization: then_branch and else_branch point to regions
+ then_block = then_region->GetBlocks().front();
+ else_block = else_region->GetBlocks().front();
+ }
+ else
+ {
+ // old serialization: then_branch and else_branch point to blocks in curr_region
+ auto region_name = getParentSGT()->getRegionName();
+ auto curr_region = tsh->GetRegionByName(region_name);
+ then_block = curr_region->GetBlockByName(attribute->then_branch());
+ else_block = curr_region->GetBlockByName(attribute->else_branch());
+ }
ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str());
@@ -309,10 +321,21 @@ int OpWhileLoop::checkTensorAttributes()
return 1;
}
- auto region_name = getParentSGT()->getRegionName();
- auto curr_region = tsh->GetRegionByName(region_name);
- cond_block = curr_region->GetBlockByName(attribute->cond_branch());
- body_block = curr_region->GetBlockByName(attribute->body_branch());
+ auto cond_region = tsh->GetRegionByName(attribute->cond_branch());
+ auto body_region = tsh->GetRegionByName(attribute->body_branch());
+ if (cond_region && body_region)
+ {
+ // new serialization: then_branch and else_branch point to regions
+ cond_block = cond_region->GetBlocks().front();
+ body_block = body_region->GetBlocks().front();
+ }
+ else
+ {
+ auto region_name = getParentSGT()->getRegionName();
+ auto curr_region = tsh->GetRegionByName(region_name);
+ cond_block = curr_region->GetBlockByName(attribute->cond_branch());
+ body_block = curr_region->GetBlockByName(attribute->body_branch());
+ }
ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str());
ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str());