From 4e9a977ae5c95e2a0d323951a8cffcade9b0cbba Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 16 Mar 2023 22:24:05 +0000 Subject: [reference model] support multiple regions This allows IF/WHILE serialization to use regions instead of blocks to serialize nested regions. For backward compatibility, both region and block serialization are supported for IF/WHILE ops. Signed-off-by: Tai Ly Change-Id: Icf935561f9f5db38767ff76410bcd36896119395 --- reference_model/src/main.cpp | 2 +- reference_model/src/ops/control_flow.cc | 39 ++++++++++++++++++++++++------- reference_model/src/subgraph_traverser.cc | 9 ++++--- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index 0375a48..ccc65f9 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -83,7 +83,7 @@ int main(int argc, char** argv) FATAL_ERROR("Unable to load graph"); } - SubgraphTraverser main_gt(tsh.GetMainRegion()->GetBlockByName("main"), &tsh, nullptr); + SubgraphTraverser main_gt(tsh.GetMainRegion()->GetBlocks().front(), &tsh, nullptr); if (main_gt.initializeGraph()) { diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 942652d..f573d5b 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -181,10 +181,22 @@ int OpCondIf::checkTensorAttributes() cond = dynamic_cast*>(inputs[0]); ASSERT_MEM(cond); - auto region_name = getParentSGT()->getRegionName(); - auto curr_region = tsh->GetRegionByName(region_name); - then_block = curr_region->GetBlockByName(attribute->then_branch()); - else_block = curr_region->GetBlockByName(attribute->else_branch()); + auto then_region = tsh->GetRegionByName(attribute->then_branch()); + auto else_region = tsh->GetRegionByName(attribute->else_branch()); + if (then_region && else_region) + { + // new serialization: then_branch and else_branch point to regions + then_block = then_region->GetBlocks().front(); + else_block = else_region->GetBlocks().front(); + } + else + { + // old serialization: then_branch and else_branch point to blocks in curr_region + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + then_block = curr_region->GetBlockByName(attribute->then_branch()); + else_block = curr_region->GetBlockByName(attribute->else_branch()); + } ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str()); @@ -309,10 +321,21 @@ int OpWhileLoop::checkTensorAttributes() return 1; } - auto region_name = getParentSGT()->getRegionName(); - auto curr_region = tsh->GetRegionByName(region_name); - cond_block = curr_region->GetBlockByName(attribute->cond_branch()); - body_block = curr_region->GetBlockByName(attribute->body_branch()); + auto cond_region = tsh->GetRegionByName(attribute->cond_branch()); + auto body_region = tsh->GetRegionByName(attribute->body_branch()); + if (cond_region && body_region) + { + // new serialization: then_branch and else_branch point to regions + cond_block = cond_region->GetBlocks().front(); + body_block = body_region->GetBlocks().front(); + } + else + { + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + cond_block = curr_region->GetBlockByName(attribute->cond_branch()); + body_block = curr_region->GetBlockByName(attribute->body_branch()); + } ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str()); ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str()); diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 8867ada..e7641ba 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -123,11 +123,14 @@ int SubgraphTraverser::initializeGraph() std::vector ser_tensor_vec; // Get all the serialized tensors from TosaSerializationHandler. - for (auto block: tsh->GetMainRegion()->GetBlocks()) + for (auto region : tsh->GetRegions()) { - for (auto ser_tensor : block->GetTensors()) + for (auto block : region->GetBlocks()) { - ser_tensor_vec.push_back(ser_tensor); + for (auto ser_tensor : block->GetTensors()) + { + ser_tensor_vec.push_back(ser_tensor); + } } } -- cgit v1.2.1