diff options
Diffstat (limited to 'reference_model')
-rw-r--r-- | reference_model/src/model_runner_impl.cc | 6 | ||||
-rw-r--r-- | reference_model/src/operators.cc | 2 | ||||
-rw-r--r-- | reference_model/src/ops/control_flow.cc | 7 | ||||
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 18 |
4 files changed, 29 insertions, 4 deletions
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc index 8089a1a..31e100a 100644 --- a/reference_model/src/model_runner_impl.cc +++ b/reference_model/src/model_runner_impl.cc @@ -264,6 +264,12 @@ int ModelRunnerImpl::getOutput(std::string output_name, uint8_t* raw_ptr, size_t status = tensor->writeToVector(ArrayProxy(elements, typed_ptr)); break; } + case TOSA_REF_TYPE_BOOL: { + auto typed_ptr = reinterpret_cast<unsigned char*>(raw_ptr); + const int elements = size / sizeof(unsigned char); + status = tensor->writeToVector(ArrayProxy(elements, typed_ptr)); + break; + } default: status = 1; } diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index a627322..a070326 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -50,6 +50,8 @@ tosa::DType translate_client_datatype(tosa_datatype_t type) return tosa::DType::DType_FP16; case tosa_datatype_fp32_t: return tosa::DType::DType_FP32; + case tosa_datatype_bool_t: + return tosa::DType::DType_BOOL; default: return tosa::DType::DType_UNKNOWN; } diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 03ad6c6..a0e1fc2 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -172,6 +172,8 @@ OpCondIf::~OpCondIf() int OpCondIf::checkTensorAttributes() { + ERROR_IF(!tsh, "OpCondIf: tosa serialization handler must not be null"); + ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand"); ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0, @@ -309,6 +311,11 @@ OpWhileLoop::~OpWhileLoop() int OpWhileLoop::checkTensorAttributes() { + if (!tsh) { + WARNING("OpWhileLoop: tosa serialization handler must not be null"); + return 1; + } + if (getInputs().size() <= 0) { WARNING("OpWhileLoop: must have at least 1 operands"); diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 4508291..c02581f 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -123,16 +123,26 @@ int SubgraphTraverser::initializeGraph() std::vector<TosaSerializationTensor*> ser_tensor_vec; // Get all the serialized tensors from TosaSerializationHandler. - for (auto region : tsh->GetRegions()) + if (tsh) { - for (auto block : region->GetBlocks()) + for (auto region : tsh->GetRegions()) { - for (auto ser_tensor : block->GetTensors()) + for (auto block : region->GetBlocks()) { - ser_tensor_vec.push_back(ser_tensor); + for (auto ser_tensor : block->GetTensors()) + { + ser_tensor_vec.push_back(ser_tensor); + } } } } + else + { + for (auto ser_tensor : block->GetTensors()) + { + ser_tensor_vec.push_back(ser_tensor); + } + } std::vector<GraphNode*> non_const_node_vec; for (auto op : block->GetOperators()) |