diff options
-rw-r--r-- | reference_model/src/main.cpp | 2 | ||||
-rw-r--r-- | reference_model/src/ops/control_flow.cc | 39 | ||||
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 9 |
3 files changed, 38 insertions, 12 deletions
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index 0375a48..ccc65f9 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -83,7 +83,7 @@ int main(int argc, char** argv) FATAL_ERROR("Unable to load graph"); } - SubgraphTraverser main_gt(tsh.GetMainRegion()->GetBlockByName("main"), &tsh, nullptr); + SubgraphTraverser main_gt(tsh.GetMainRegion()->GetBlocks().front(), &tsh, nullptr); if (main_gt.initializeGraph()) { diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 942652d..f573d5b 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -181,10 +181,22 @@ int OpCondIf::checkTensorAttributes() cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]); ASSERT_MEM(cond); - auto region_name = getParentSGT()->getRegionName(); - auto curr_region = tsh->GetRegionByName(region_name); - then_block = curr_region->GetBlockByName(attribute->then_branch()); - else_block = curr_region->GetBlockByName(attribute->else_branch()); + auto then_region = tsh->GetRegionByName(attribute->then_branch()); + auto else_region = tsh->GetRegionByName(attribute->else_branch()); + if (then_region && else_region) + { + // new serialization: then_branch and else_branch point to regions + then_block = then_region->GetBlocks().front(); + else_block = else_region->GetBlocks().front(); + } + else + { + // old serialization: then_branch and else_branch point to blocks in curr_region + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + then_block = curr_region->GetBlockByName(attribute->then_branch()); + else_block = curr_region->GetBlockByName(attribute->else_branch()); + } ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str()); @@ -309,10 +321,21 @@ int OpWhileLoop::checkTensorAttributes() return 1; } - auto region_name = getParentSGT()->getRegionName(); - auto curr_region = tsh->GetRegionByName(region_name); - cond_block = curr_region->GetBlockByName(attribute->cond_branch()); - body_block = curr_region->GetBlockByName(attribute->body_branch()); + auto cond_region = tsh->GetRegionByName(attribute->cond_branch()); + auto body_region = tsh->GetRegionByName(attribute->body_branch()); + if (cond_region && body_region) + { + // new serialization: then_branch and else_branch point to regions + cond_block = cond_region->GetBlocks().front(); + body_block = body_region->GetBlocks().front(); + } + else + { + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + cond_block = curr_region->GetBlockByName(attribute->cond_branch()); + body_block = curr_region->GetBlockByName(attribute->body_branch()); + } ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str()); ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str()); diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 8867ada..e7641ba 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -123,11 +123,14 @@ int SubgraphTraverser::initializeGraph() std::vector<TosaSerializationTensor*> ser_tensor_vec; // Get all the serialized tensors from TosaSerializationHandler. - for (auto block: tsh->GetMainRegion()->GetBlocks()) + for (auto region : tsh->GetRegions()) { - for (auto ser_tensor : block->GetTensors()) + for (auto block : region->GetBlocks()) { - ser_tensor_vec.push_back(ser_tensor); + for (auto ser_tensor : block->GetTensors()) + { + ser_tensor_vec.push_back(ser_tensor); + } } } |