aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJiacheng Liang <jiacheng.liang@arm.com>2023-05-17 16:49:44 +0100
committerEric Kunze <eric.kunze@arm.com>2023-06-02 21:29:19 +0000
commiteb52cc18b342d6329322f84b671eab4450e663fd (patch)
treec6bc5b6fbd064b7fe0a7ae0c0c2be4d9939bd813
parent56a3a06c260714703cb531bac3417ed37aebe6ce (diff)
downloadreference_model-eb52cc18b342d6329322f84b671eab4450e663fd.tar.gz
Add support for boolean outputs in model runner
Comparison operators produce boolean outputs, which need to be written into client data Allow subgraph traverser to use main block to look for tensors when serialization handler is missing Signed-off-by: Jiacheng Liang <jiacheng.liang@arm.com> Change-Id: I6f9af470185541fa6466b3f7786c48f1555fa6f6
-rw-r--r--reference_model/src/model_runner_impl.cc6
-rw-r--r--reference_model/src/operators.cc2
-rw-r--r--reference_model/src/ops/control_flow.cc7
-rw-r--r--reference_model/src/subgraph_traverser.cc18
4 files changed, 29 insertions, 4 deletions
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
index 8089a1a..31e100a 100644
--- a/reference_model/src/model_runner_impl.cc
+++ b/reference_model/src/model_runner_impl.cc
@@ -264,6 +264,12 @@ int ModelRunnerImpl::getOutput(std::string output_name, uint8_t* raw_ptr, size_t
status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
break;
}
+ case TOSA_REF_TYPE_BOOL: {
+ auto typed_ptr = reinterpret_cast<unsigned char*>(raw_ptr);
+ const int elements = size / sizeof(unsigned char);
+ status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
+ break;
+ }
default:
status = 1;
}
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index a627322..a070326 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -50,6 +50,8 @@ tosa::DType translate_client_datatype(tosa_datatype_t type)
return tosa::DType::DType_FP16;
case tosa_datatype_fp32_t:
return tosa::DType::DType_FP32;
+ case tosa_datatype_bool_t:
+ return tosa::DType::DType_BOOL;
default:
return tosa::DType::DType_UNKNOWN;
}
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index 03ad6c6..a0e1fc2 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -172,6 +172,8 @@ OpCondIf::~OpCondIf()
int OpCondIf::checkTensorAttributes()
{
+ ERROR_IF(!tsh, "OpCondIf: tosa serialization handler must not be null");
+
ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand");
ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0,
@@ -309,6 +311,11 @@ OpWhileLoop::~OpWhileLoop()
int OpWhileLoop::checkTensorAttributes()
{
+ if (!tsh) {
+ WARNING("OpWhileLoop: tosa serialization handler must not be null");
+ return 1;
+ }
+
if (getInputs().size() <= 0)
{
WARNING("OpWhileLoop: must have at least 1 operands");
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 4508291..c02581f 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -123,16 +123,26 @@ int SubgraphTraverser::initializeGraph()
std::vector<TosaSerializationTensor*> ser_tensor_vec;
// Get all the serialized tensors from TosaSerializationHandler.
- for (auto region : tsh->GetRegions())
+ if (tsh)
{
- for (auto block : region->GetBlocks())
+ for (auto region : tsh->GetRegions())
{
- for (auto ser_tensor : block->GetTensors())
+ for (auto block : region->GetBlocks())
{
- ser_tensor_vec.push_back(ser_tensor);
+ for (auto ser_tensor : block->GetTensors())
+ {
+ ser_tensor_vec.push_back(ser_tensor);
+ }
}
}
}
+ else
+ {
+ for (auto ser_tensor : block->GetTensors())
+ {
+ ser_tensor_vec.push_back(ser_tensor);
+ }
+ }
std::vector<GraphNode*> non_const_node_vec;
for (auto op : block->GetOperators())