aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-03-16 22:24:05 +0000
committerTai Ly <tai.ly@arm.com>2023-04-06 20:21:27 +0000
commit4e9a977ae5c95e2a0d323951a8cffcade9b0cbba (patch)
tree89f40ed571faab3ad0a206e52a87cdb715d2d0c3
parentb1f25015d4be6c9b8cd399d7e14fea98cd2f01f5 (diff)
downloadreference_model-4e9a977ae5c95e2a0d323951a8cffcade9b0cbba.tar.gz
[reference model] support multiple regions
This allows IF/WHILE serialization to use regions instead of blocks to serialize nested regions. For backward compatibility, both region and block serialization are supported for IF/WHILE ops. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Icf935561f9f5db38767ff76410bcd36896119395
-rw-r--r--reference_model/src/main.cpp2
-rw-r--r--reference_model/src/ops/control_flow.cc39
-rw-r--r--reference_model/src/subgraph_traverser.cc9
3 files changed, 38 insertions, 12 deletions
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 0375a48..ccc65f9 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -83,7 +83,7 @@ int main(int argc, char** argv)
FATAL_ERROR("Unable to load graph");
}
- SubgraphTraverser main_gt(tsh.GetMainRegion()->GetBlockByName("main"), &tsh, nullptr);
+ SubgraphTraverser main_gt(tsh.GetMainRegion()->GetBlocks().front(), &tsh, nullptr);
if (main_gt.initializeGraph())
{
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index 942652d..f573d5b 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -181,10 +181,22 @@ int OpCondIf::checkTensorAttributes()
cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
ASSERT_MEM(cond);
- auto region_name = getParentSGT()->getRegionName();
- auto curr_region = tsh->GetRegionByName(region_name);
- then_block = curr_region->GetBlockByName(attribute->then_branch());
- else_block = curr_region->GetBlockByName(attribute->else_branch());
+ auto then_region = tsh->GetRegionByName(attribute->then_branch());
+ auto else_region = tsh->GetRegionByName(attribute->else_branch());
+ if (then_region && else_region)
+ {
+ // new serialization: then_branch and else_branch point to regions
+ then_block = then_region->GetBlocks().front();
+ else_block = else_region->GetBlocks().front();
+ }
+ else
+ {
+ // old serialization: then_branch and else_branch point to blocks in curr_region
+ auto region_name = getParentSGT()->getRegionName();
+ auto curr_region = tsh->GetRegionByName(region_name);
+ then_block = curr_region->GetBlockByName(attribute->then_branch());
+ else_block = curr_region->GetBlockByName(attribute->else_branch());
+ }
ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str());
@@ -309,10 +321,21 @@ int OpWhileLoop::checkTensorAttributes()
return 1;
}
- auto region_name = getParentSGT()->getRegionName();
- auto curr_region = tsh->GetRegionByName(region_name);
- cond_block = curr_region->GetBlockByName(attribute->cond_branch());
- body_block = curr_region->GetBlockByName(attribute->body_branch());
+ auto cond_region = tsh->GetRegionByName(attribute->cond_branch());
+ auto body_region = tsh->GetRegionByName(attribute->body_branch());
+ if (cond_region && body_region)
+ {
+ // new serialization: then_branch and else_branch point to regions
+ cond_block = cond_region->GetBlocks().front();
+ body_block = body_region->GetBlocks().front();
+ }
+ else
+ {
+ auto region_name = getParentSGT()->getRegionName();
+ auto curr_region = tsh->GetRegionByName(region_name);
+ cond_block = curr_region->GetBlockByName(attribute->cond_branch());
+ body_block = curr_region->GetBlockByName(attribute->body_branch());
+ }
ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str());
ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str());
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 8867ada..e7641ba 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -123,11 +123,14 @@ int SubgraphTraverser::initializeGraph()
std::vector<TosaSerializationTensor*> ser_tensor_vec;
// Get all the serialized tensors from TosaSerializationHandler.
- for (auto block: tsh->GetMainRegion()->GetBlocks())
+ for (auto region : tsh->GetRegions())
{
- for (auto ser_tensor : block->GetTensors())
+ for (auto block : region->GetBlocks())
{
- ser_tensor_vec.push_back(ser_tensor);
+ for (auto ser_tensor : block->GetTensors())
+ {
+ ser_tensor_vec.push_back(ser_tensor);
+ }
}
}