diff options
Diffstat (limited to 'reference_model/src/ops/control_flow.cc')
-rw-r--r-- | reference_model/src/ops/control_flow.cc | 39 |
1 files changed, 31 insertions, 8 deletions
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 942652d..f573d5b 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -181,10 +181,22 @@ int OpCondIf::checkTensorAttributes() cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]); ASSERT_MEM(cond); - auto region_name = getParentSGT()->getRegionName(); - auto curr_region = tsh->GetRegionByName(region_name); - then_block = curr_region->GetBlockByName(attribute->then_branch()); - else_block = curr_region->GetBlockByName(attribute->else_branch()); + auto then_region = tsh->GetRegionByName(attribute->then_branch()); + auto else_region = tsh->GetRegionByName(attribute->else_branch()); + if (then_region && else_region) + { + // new serialization: then_branch and else_branch point to regions + then_block = then_region->GetBlocks().front(); + else_block = else_region->GetBlocks().front(); + } + else + { + // old serialization: then_branch and else_branch point to blocks in curr_region + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + then_block = curr_region->GetBlockByName(attribute->then_branch()); + else_block = curr_region->GetBlockByName(attribute->else_branch()); + } ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str()); @@ -309,10 +321,21 @@ int OpWhileLoop::checkTensorAttributes() return 1; } - auto region_name = getParentSGT()->getRegionName(); - auto curr_region = tsh->GetRegionByName(region_name); - cond_block = curr_region->GetBlockByName(attribute->cond_branch()); - body_block = curr_region->GetBlockByName(attribute->body_branch()); + auto cond_region = tsh->GetRegionByName(attribute->cond_branch()); + auto body_region = tsh->GetRegionByName(attribute->body_branch()); + if (cond_region && body_region) + { + // new serialization: then_branch and else_branch point to regions + cond_block = cond_region->GetBlocks().front(); + body_block = body_region->GetBlocks().front(); + } + else + { + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + cond_block = curr_region->GetBlockByName(attribute->cond_branch()); + body_block = curr_region->GetBlockByName(attribute->body_branch()); + } ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str()); ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str()); |