aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/control_flow.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/control_flow.cc')
-rw-r--r--reference_model/src/ops/control_flow.cc25
1 files changed, 12 insertions, 13 deletions
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index 7105caf..942652d 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
#include "control_flow.h"
#include "subgraph_traverser.h"
-
using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
@@ -37,7 +36,7 @@ int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block,
DEBUG_MED(OP, "Evaluating block %s", block_name.c_str());
- SubgraphTraverser block_sgt(block, tsh);
+ SubgraphTraverser block_sgt(block, tsh, this->parent_sgt);
ERROR_IF(block_sgt.initializeGraph(), "evalBlock(): Unable to initialize graph traverser for %s",
block_name.c_str());
@@ -182,8 +181,10 @@ int OpCondIf::checkTensorAttributes()
cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
ASSERT_MEM(cond);
- then_block = tsh->GetBlockByName(attribute->then_branch());
- else_block = tsh->GetBlockByName(attribute->else_branch());
+ auto region_name = getParentSGT()->getRegionName();
+ auto curr_region = tsh->GetRegionByName(region_name);
+ then_block = curr_region->GetBlockByName(attribute->then_branch());
+ else_block = curr_region->GetBlockByName(attribute->else_branch());
ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str());
@@ -193,6 +194,7 @@ int OpCondIf::checkTensorAttributes()
// Skip the first rank 0 bool tensor on input list
int32_t num_input_tensor = getInputs().size() - 1;
int32_t num_output_tensor = getOutputs().size();
+
ERROR_IF((int32_t)then_block->GetInputs().size() != num_input_tensor,
"OpCondIf: then_block has unexpected number of input");
ERROR_IF((int32_t)else_block->GetInputs().size() != num_input_tensor,
@@ -307,8 +309,10 @@ int OpWhileLoop::checkTensorAttributes()
return 1;
}
- cond_block = tsh->GetBlockByName(attribute->cond_branch());
- body_block = tsh->GetBlockByName(attribute->body_branch());
+ auto region_name = getParentSGT()->getRegionName();
+ auto curr_region = tsh->GetRegionByName(region_name);
+ cond_block = curr_region->GetBlockByName(attribute->cond_branch());
+ body_block = curr_region->GetBlockByName(attribute->body_branch());
ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str());
ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str());
@@ -403,12 +407,7 @@ int OpWhileLoop::eval()
// assigning output tensors value back to input tensors value for next iteration
for (size_t i = 0; i < num_input_output; i++)
{
- if (getInputs()[i]->copyValueFrom(getOutputs()[i]))
- {
- WARNING("Fail to copy tensor value %s -> %s", getOutputs()[i]->getName().c_str(),
- getInputs()[i]->getName().c_str());
- return 1;
- }
+ getInputs()[i] = getOutputs()[i];
}
}
else