aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--reference_model/src/model_runner_impl.cc6
-rw-r--r--reference_model/src/operators.cc2
-rw-r--r--reference_model/src/ops/control_flow.cc7
-rw-r--r--reference_model/src/subgraph_traverser.cc18
4 files changed, 29 insertions, 4 deletions
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
index 8089a1a..31e100a 100644
--- a/reference_model/src/model_runner_impl.cc
+++ b/reference_model/src/model_runner_impl.cc
@@ -264,6 +264,12 @@ int ModelRunnerImpl::getOutput(std::string output_name, uint8_t* raw_ptr, size_t
status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
break;
}
+ case TOSA_REF_TYPE_BOOL: {
+ auto typed_ptr = reinterpret_cast<unsigned char*>(raw_ptr);
+ const int elements = size / sizeof(unsigned char);
+ status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
+ break;
+ }
default:
status = 1;
}
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index a627322..a070326 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -50,6 +50,8 @@ tosa::DType translate_client_datatype(tosa_datatype_t type)
return tosa::DType::DType_FP16;
case tosa_datatype_fp32_t:
return tosa::DType::DType_FP32;
+ case tosa_datatype_bool_t:
+ return tosa::DType::DType_BOOL;
default:
return tosa::DType::DType_UNKNOWN;
}
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index 03ad6c6..a0e1fc2 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -172,6 +172,8 @@ OpCondIf::~OpCondIf()
int OpCondIf::checkTensorAttributes()
{
+ ERROR_IF(!tsh, "OpCondIf: tosa serialization handler must not be null");
+
ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand");
ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0,
@@ -309,6 +311,11 @@ OpWhileLoop::~OpWhileLoop()
int OpWhileLoop::checkTensorAttributes()
{
+ if (!tsh) {
+ WARNING("OpWhileLoop: tosa serialization handler must not be null");
+ return 1;
+ }
+
if (getInputs().size() <= 0)
{
WARNING("OpWhileLoop: must have at least 1 operands");
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 4508291..c02581f 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -123,16 +123,26 @@ int SubgraphTraverser::initializeGraph()
std::vector<TosaSerializationTensor*> ser_tensor_vec;
// Get all the serialized tensors from TosaSerializationHandler.
- for (auto region : tsh->GetRegions())
+ if (tsh)
{
- for (auto block : region->GetBlocks())
+ for (auto region : tsh->GetRegions())
{
- for (auto ser_tensor : block->GetTensors())
+ for (auto block : region->GetBlocks())
{
- ser_tensor_vec.push_back(ser_tensor);
+ for (auto ser_tensor : block->GetTensors())
+ {
+ ser_tensor_vec.push_back(ser_tensor);
+ }
}
}
}
+ else
+ {
+ for (auto ser_tensor : block->GetTensors())
+ {
+ ser_tensor_vec.push_back(ser_tensor);
+ }
+ }
std::vector<GraphNode*> non_const_node_vec;
for (auto op : block->GetOperators())