diff options
author | Tai Ly <tai.ly@arm.com> | 2023-12-18 20:40:24 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2024-01-18 23:50:04 +0000 |
commit | 8690a0873fac28acccbb0acb15c16e8337e14df1 (patch) | |
tree | a13d5e195d8b7becffc23da98fde7449e91c96e4 /reference_model/src/subgraph_traverser.cc | |
parent | 9f5febe05901bfbd3919ef17f2caea8087cd9ccf (diff) | |
download | reference_model-8690a0873fac28acccbb0acb15c16e8337e14df1.tar.gz |
[reference model] Add shape operators
- fixed up reshape conformance tests to use shape input instead of attribute
- fixed up tile conformance tests to use shape input instead of attribute
- fixed output and output rank of dim op
- allow rank 0 and rank 1 tensors for tosa.shape values (for shape = {})
- added initialization of rank 0 const_shape tensors (for shape = {})
- Update conformance tests to use new rescale attributes
Signed-off-by: Tai Ly <tai.ly@arm.com>
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I6cce0d2a9ab066fe20a2abf9d2cfde3eb3d8c18b
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 19 |
1 files changed, 17 insertions, 2 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 745213e..fae0b30 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -510,6 +510,12 @@ int SubgraphTraverser::allocateTensor(std::string name) FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str()); } + // set valid for constant tensors: + if ((ts->GetShape().empty() && ts->GetDtype() == DType_SHAPE)) + { + // corner case: const_shape {} has no data + tensor->setIsValid(); + } if (!ts->GetData().empty()) { if (ts->GetVariable() && g_func_config.initialize_variable_tensor_from_numpy) @@ -545,13 +551,18 @@ int SubgraphTraverser::allocateTensor(std::string name) tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); } break; - case DType_INT48: - case DType_SHAPE: { + case DType_INT48: { std::vector<int64_t> i64_data; TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data); tensor->setTensorValueInt64(i64_data.size(), i64_data.data()); } break; + case DType_SHAPE: { + std::vector<int64_t> i64_data; + TosaSerializationHandler::ConvertU8toI64(ts->GetData(), tensor->getElementCount(), i64_data); + tensor->setTensorValueInt64(i64_data.size(), i64_data.data()); + } + break; case DType_FP16: { // Interpret f16 data as float std::vector<float> f16_data; @@ -617,6 +628,10 @@ int SubgraphTraverser::allocateTensor(std::string name) EnumNameDType(ts->GetDtype())); } tensor->setIsValid(); + } + + if (tensor->getIsValid()) + { // Push ready consumers to the next node list for (auto gn : tensor->getConsumers()) { |