diff options
Diffstat (limited to 'reference_model/src/ops/control_flow.cc')
-rw-r--r-- | reference_model/src/ops/control_flow.cc | 25 |
1 files changed, 12 insertions, 13 deletions
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 7105caf..942652d 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,7 +15,6 @@ #include "control_flow.h" #include "subgraph_traverser.h" - using namespace TosaReference; using namespace Eigen; using namespace tosa; @@ -37,7 +36,7 @@ int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block, DEBUG_MED(OP, "Evaluating block %s", block_name.c_str()); - SubgraphTraverser block_sgt(block, tsh); + SubgraphTraverser block_sgt(block, tsh, this->parent_sgt); ERROR_IF(block_sgt.initializeGraph(), "evalBlock(): Unable to initialize graph traverser for %s", block_name.c_str()); @@ -182,8 +181,10 @@ int OpCondIf::checkTensorAttributes() cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]); ASSERT_MEM(cond); - then_block = tsh->GetBlockByName(attribute->then_branch()); - else_block = tsh->GetBlockByName(attribute->else_branch()); + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + then_block = curr_region->GetBlockByName(attribute->then_branch()); + else_block = curr_region->GetBlockByName(attribute->else_branch()); ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str()); @@ -193,6 +194,7 @@ int OpCondIf::checkTensorAttributes() // Skip the first rank 0 bool tensor on input list int32_t num_input_tensor = getInputs().size() - 1; int32_t num_output_tensor = getOutputs().size(); + ERROR_IF((int32_t)then_block->GetInputs().size() != num_input_tensor, "OpCondIf: then_block has unexpected number of input"); ERROR_IF((int32_t)else_block->GetInputs().size() != num_input_tensor, @@ -307,8 +309,10 @@ int OpWhileLoop::checkTensorAttributes() return 1; } - cond_block = tsh->GetBlockByName(attribute->cond_branch()); - body_block = tsh->GetBlockByName(attribute->body_branch()); + auto region_name = getParentSGT()->getRegionName(); + auto curr_region = tsh->GetRegionByName(region_name); + cond_block = curr_region->GetBlockByName(attribute->cond_branch()); + body_block = curr_region->GetBlockByName(attribute->body_branch()); ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str()); ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str()); @@ -403,12 +407,7 @@ int OpWhileLoop::eval() // assigning output tensors value back to input tensors value for next iteration for (size_t i = 0; i < num_input_output; i++) { - if (getInputs()[i]->copyValueFrom(getOutputs()[i])) - { - WARNING("Fail to copy tensor value %s -> %s", getOutputs()[i]->getName().c_str(), - getInputs()[i]->getName().c_str()); - return 1; - } + getInputs()[i] = getOutputs()[i]; } } else |