diff options
Diffstat (limited to 'reference_model/src/ops/control_flow.cc')
-rw-r--r-- | reference_model/src/ops/control_flow.cc | 38 |
1 files changed, 19 insertions, 19 deletions
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 6bbc587..ac09bbb 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -181,26 +181,26 @@ int OpCondIf::checkTensorAttributes() cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]); ASSERT_MEM(cond); - auto then_region = tsh->GetRegionByName(attribute->then_branch()); - auto else_region = tsh->GetRegionByName(attribute->else_branch()); + auto then_region = tsh->GetRegionByName(attribute->then_graph()); + auto else_region = tsh->GetRegionByName(attribute->else_graph()); if (then_region && else_region) { - // new serialization: then_branch and else_branch point to regions + // new serialization: then_graph and else_graph point to regions then_block = then_region->GetBlocks().front(); else_block = else_region->GetBlocks().front(); } else { - // old serialization: then_branch and else_branch point to blocks in curr_region + // old serialization: then_graph and else_graph point to blocks in curr_region auto region_name = getParentSGT()->getRegionName(); auto curr_region = tsh->GetRegionByName(region_name); - then_block = curr_region->GetBlockByName(attribute->then_branch()); - else_block = curr_region->GetBlockByName(attribute->else_branch()); + then_block = curr_region->GetBlockByName(attribute->then_graph()); + else_block = curr_region->GetBlockByName(attribute->else_graph()); } - ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str()); + ERROR_IF(!then_block, "OpCondIf: fail to resolve then_graph %s", attribute->then_graph().c_str()); - ERROR_IF(!else_block, "OpCondIf: fail to resolve else_branch %s", attribute->else_branch().c_str()); + ERROR_IF(!else_block, "OpCondIf: fail to resolve else_graph %s", attribute->else_graph().c_str()); // Make sure operator input/output matches block input/output // Skip the first rank 0 bool tensor on input list @@ -276,7 +276,7 @@ int OpCondIf::eval() { if (evalBlock(then_block, block_inputs, getOutputs())) { - WARNING("OpCondIf: Fail to evaluate then branch block %s", attribute->then_branch().c_str()); + WARNING("OpCondIf: Fail to evaluate then branch block %s", attribute->then_graph().c_str()); return 1; } } @@ -284,7 +284,7 @@ int OpCondIf::eval() { if (evalBlock(else_block, block_inputs, getOutputs())) { - WARNING("OpCondIf: Fail to evaluate else branch block %s", attribute->else_branch().c_str()); + WARNING("OpCondIf: Fail to evaluate else branch block %s", attribute->else_graph().c_str()); return 1; } } @@ -327,11 +327,11 @@ int OpWhileLoop::checkTensorAttributes() return 1; } - auto cond_region = tsh->GetRegionByName(attribute->cond_branch()); - auto body_region = tsh->GetRegionByName(attribute->body_branch()); + auto cond_region = tsh->GetRegionByName(attribute->cond_graph()); + auto body_region = tsh->GetRegionByName(attribute->body_graph()); if (cond_region && body_region) { - // new serialization: then_branch and else_branch point to regions + // new serialization: then_graph and else_graph point to regions cond_block = cond_region->GetBlocks().front(); body_block = body_region->GetBlocks().front(); } @@ -339,12 +339,12 @@ int OpWhileLoop::checkTensorAttributes() { auto region_name = getParentSGT()->getRegionName(); auto curr_region = tsh->GetRegionByName(region_name); - cond_block = curr_region->GetBlockByName(attribute->cond_branch()); - body_block = curr_region->GetBlockByName(attribute->body_branch()); + cond_block = curr_region->GetBlockByName(attribute->cond_graph()); + body_block = curr_region->GetBlockByName(attribute->body_graph()); } - ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str()); - ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str()); + ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_graph %s", attribute->cond_graph().c_str()); + ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_graph %s", attribute->body_graph().c_str()); // Make sure operator input/output matches block input/output int32_t num_block_tensor = getInputs().size(); @@ -418,7 +418,7 @@ int OpWhileLoop::eval() { if (evalBlock(cond_block, getInputs(), cond_block_outputs)) { - WARNING("OpWhileLoop: Fail to evaluate cond block %s", attribute->cond_branch().c_str()); + WARNING("OpWhileLoop: Fail to evaluate cond block %s", attribute->cond_graph().c_str()); return 1; } bool cond_val = cond_output_ctensor.getTensor()(0); @@ -428,7 +428,7 @@ int OpWhileLoop::eval() { if (evalBlock(body_block, getInputs(), getOutputs())) { - WARNING("OpWhileLoop: Fail to evaluate body block %s", attribute->body_branch().c_str()); + WARNING("OpWhileLoop: Fail to evaluate body block %s", attribute->body_graph().c_str()); return 1; } |