diff options
author | Tai Ly <tai.ly@arm.com> | 2023-03-16 22:24:05 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-04-06 20:21:27 +0000 |
commit | 4e9a977ae5c95e2a0d323951a8cffcade9b0cbba (patch) | |
tree | 89f40ed571faab3ad0a206e52a87cdb715d2d0c3 /reference_model/src/ops/control_flow.cc | |
parent | b1f25015d4be6c9b8cd399d7e14fea98cd2f01f5 (diff) | |
download | reference_model-4e9a977ae5c95e2a0d323951a8cffcade9b0cbba.tar.gz |
[reference model] support multiple regions
This allows IF/WHILE serialization to use regions
instead of blocks to serialize nested regions.
For backward compatibility, both region and block
serialization are supported for IF/WHILE ops.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: Icf935561f9f5db38767ff76410bcd36896119395
Diffstat (limited to 'reference_model/src/ops/control_flow.cc')
-rw-r--r-- | reference_model/src/ops/control_flow.cc | 39 |
1 files changed, 31 insertions, 8 deletions
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 942652d..f573d5b 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -181,10 +181,22 @@ int OpCondIf::checkTensorAttributes() cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]); ASSERT_MEM(cond); - auto region_name = getParentSGT()->getRegionName(); - auto curr_region = tsh->GetRegionByName(region_name); - then_block = curr_region->GetBlockByName(attribute->then_branch()); - else_block = curr_region->GetBlockByName(attribute->else_branch()); + auto then_region = tsh->GetRegionByName(attribute->then_branch()); + auto else_region = tsh->GetRegionByName(attribute->else_branch()); + if (then_region && else_region) + { + // new serialization: then_branch and else_branch point to regions + then_block = then_region->GetBlocks().front(); + else_block = else_region->GetBlocks().front(); + } + else + { + // old serialization: then_branch and else_branch point to blocks in curr_region + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + then_block = curr_region->GetBlockByName(attribute->then_branch()); + else_block = curr_region->GetBlockByName(attribute->else_branch()); + } ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str()); @@ -309,10 +321,21 @@ int OpWhileLoop::checkTensorAttributes() return 1; } - auto region_name = getParentSGT()->getRegionName(); - auto curr_region = tsh->GetRegionByName(region_name); - cond_block = curr_region->GetBlockByName(attribute->cond_branch()); - body_block = curr_region->GetBlockByName(attribute->body_branch()); + auto cond_region = tsh->GetRegionByName(attribute->cond_branch()); + auto body_region = tsh->GetRegionByName(attribute->body_branch()); + if (cond_region && body_region) + { + // new serialization: then_branch and else_branch point to regions + cond_block = cond_region->GetBlocks().front(); + body_block = body_region->GetBlocks().front(); + } + else + { + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + cond_block = curr_region->GetBlockByName(attribute->cond_branch()); + body_block = curr_region->GetBlockByName(attribute->body_branch()); + } ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str()); ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str()); |