From 4e9a977ae5c95e2a0d323951a8cffcade9b0cbba Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 16 Mar 2023 22:24:05 +0000 Subject: [reference model] support multiple regions This allows IF/WHILE serialization to use regions instead of blocks to serialize nested regions. For backward compatibility, both region and block serialization are supported for IF/WHILE ops. Signed-off-by: Tai Ly Change-Id: Icf935561f9f5db38767ff76410bcd36896119395 --- reference_model/src/ops/control_flow.cc | 39 ++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 8 deletions(-) (limited to 'reference_model/src/ops') diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 942652d..f573d5b 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -181,10 +181,22 @@ int OpCondIf::checkTensorAttributes() cond = dynamic_cast*>(inputs[0]); ASSERT_MEM(cond); - auto region_name = getParentSGT()->getRegionName(); - auto curr_region = tsh->GetRegionByName(region_name); - then_block = curr_region->GetBlockByName(attribute->then_branch()); - else_block = curr_region->GetBlockByName(attribute->else_branch()); + auto then_region = tsh->GetRegionByName(attribute->then_branch()); + auto else_region = tsh->GetRegionByName(attribute->else_branch()); + if (then_region && else_region) + { + // new serialization: then_branch and else_branch point to regions + then_block = then_region->GetBlocks().front(); + else_block = else_region->GetBlocks().front(); + } + else + { + // old serialization: then_branch and else_branch point to blocks in curr_region + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + then_block = curr_region->GetBlockByName(attribute->then_branch()); + else_block = curr_region->GetBlockByName(attribute->else_branch()); + } ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str()); @@ -309,10 +321,21 @@ int OpWhileLoop::checkTensorAttributes() return 1; } - auto region_name = getParentSGT()->getRegionName(); - auto curr_region = tsh->GetRegionByName(region_name); - cond_block = curr_region->GetBlockByName(attribute->cond_branch()); - body_block = curr_region->GetBlockByName(attribute->body_branch()); + auto cond_region = tsh->GetRegionByName(attribute->cond_branch()); + auto body_region = tsh->GetRegionByName(attribute->body_branch()); + if (cond_region && body_region) + { + // new serialization: then_branch and else_branch point to regions + cond_block = cond_region->GetBlocks().front(); + body_block = body_region->GetBlocks().front(); + } + else + { + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + cond_block = curr_region->GetBlockByName(attribute->cond_branch()); + body_block = curr_region->GetBlockByName(attribute->body_branch()); + } ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str()); ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str()); -- cgit v1.2.1